| @@ -25,7 +25,7 @@ | |||
| #include "minddata/dataset/include/constants.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h" | |||
| #include "minddata/dataset/engine/db_connector.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/util/log_adapter.h" | |||
| @@ -67,9 +67,13 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) { | |||
| if (build_num_padded_ > 0) { | |||
| sample_json = ToJson(build_sample_); | |||
| } | |||
| new_mind_record_op = std::make_shared<MindRecordOp>( | |||
| build_num_mind_record_workers_, build_dataset_file_, build_load_dataset_, build_op_connector_queue_size_, | |||
| build_columns_to_load_, build_operators_, build_num_padded_, sample_json, build_sample_bytes_); | |||
| std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>(); | |||
| new_mind_record_op = | |||
| std::make_shared<MindRecordOp>(build_num_mind_record_workers_, build_dataset_file_, build_load_dataset_, | |||
| build_op_connector_queue_size_, build_columns_to_load_, build_operators_, | |||
| build_num_padded_, sample_json, build_sample_bytes_, std::move(shard_reader)); | |||
| RETURN_IF_NOT_OK(new_mind_record_op->Init()); | |||
| *ptr = std::move(new_mind_record_op); | |||
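A note on the ownership handoff in this hunk: the builder now allocates the `ShardReader` itself and moves it into `MindRecordOp`, while the op's constructor (next hunk) hands the raw `shard_reader.get()` pointer to `MindRecordSamplerRT` before storing the `unique_ptr`. The standalone sketch below (illustrative `Owner`/`Observer` types, not MindSpore code) shows why the raw back pointer stays valid: moving a `std::unique_ptr` transfers ownership of the same heap object, so a pointer obtained via `get()` beforehand still points at it.

```cpp
#include <cassert>
#include <memory>
#include <utility>

struct Reader { int value = 42; };

// Non-owning observer, analogous to MindRecordSamplerRT holding a ShardReader*.
struct Observer {
  explicit Observer(Reader *r) : reader_(r) {}
  int Read() const { return reader_->value; }
  Reader *reader_;
};

// Owning side, analogous to MindRecordOp taking std::unique_ptr<ShardReader>.
struct Owner {
  explicit Owner(std::unique_ptr<Reader> r) : reader_(std::move(r)) {}
  std::unique_ptr<Reader> reader_;
};

int main() {
  auto reader = std::make_unique<Reader>();
  Observer observer(reader.get());  // raw back pointer taken first
  Owner owner(std::move(reader));   // ownership transferred afterwards
  assert(observer.Read() == 42);    // the heap object never moved, so the pointer is still valid
  return 0;
}
```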
| @@ -110,8 +114,10 @@ mindrecord::json MindRecordOp::Builder::ToJson(const py::handle &obj) { | |||
| MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::string> dataset_file, bool load_dataset, | |||
| int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded, | |||
| const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes) | |||
| : MappableLeafOp(num_mind_record_workers, op_connector_queue_size, std::make_shared<SequentialSamplerRT>(0, 0)), | |||
| const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes, | |||
| std::unique_ptr<ShardReader> shard_reader) | |||
| : MappableLeafOp(num_mind_record_workers, op_connector_queue_size, | |||
| std::make_shared<MindRecordSamplerRT>(shard_reader.get())), | |||
| dataset_file_(dataset_file), | |||
| load_dataset_(load_dataset), | |||
| columns_to_load_(columns_to_load), | |||
| @@ -120,7 +126,8 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str | |||
| ended_worker_(0), | |||
| num_padded_(num_padded), | |||
| sample_json_(sample_json), | |||
| sample_bytes_(sample_bytes) { | |||
| sample_bytes_(sample_bytes), | |||
| shard_reader_(std::move(shard_reader)) { | |||
| io_block_queues_.Init(num_workers_, op_connector_queue_size); | |||
 | epoch_sync_flag_ = true; // MindRecordOp needs to turn this flag on; otherwise, calling ShuffleTask() before all | |||
 | // tasks are consumed by the worker threads would cause problems. | |||
| @@ -128,7 +135,6 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str | |||
| // Private helper method to encapsulate some common construction/reset tasks | |||
| Status MindRecordOp::Init() { | |||
| shard_reader_ = std::make_unique<ShardReader>(); | |||
| auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_, | |||
| num_padded_); | |||
| @@ -363,9 +369,6 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint | |||
| Status MindRecordOp::Reset() { | |||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||
| RETURN_IF_NOT_OK(MappableLeafOp::Reset()); // Call our super class reset first. | |||
| shard_reader_->ShuffleTask(); | |||
| return Status::OK(); | |||
| } | |||
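Removing `shard_reader_->ShuffleTask()` here relies on the reshuffle now being triggered through the sampler instead. The call chain sketched below is an assumption drawn from the rest of this patch (the base-class reset driving the sampler reset is implied, not shown in this diff):

```cpp
// Assumed per-epoch flow after this change (names from the patch, flow is a sketch):
// MindRecordOp::Reset()
//   -> MappableLeafOp::Reset()                 // base-class reset, which resets the sampler
//        -> MindRecordSamplerRT::ResetSampler()
//             -> shard_reader_->ShuffleTask()  // reshuffle moved here from MindRecordOp::Reset()
```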
| @@ -140,7 +140,8 @@ class MindRecordOp : public MappableLeafOp { | |||
| MindRecordOp(int32_t num_mind_record_workers, std::vector<std::string> dataset_file, bool load_dataset, | |||
| int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded_, | |||
| const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes_); | |||
| const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes_, | |||
| std::unique_ptr<ShardReader> shard_reader); | |||
| // Destructor | |||
| ~MindRecordOp() override; | |||
| @@ -10,6 +10,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES | |||
| subset_random_sampler.cc | |||
| subset_sampler.cc | |||
| weighted_random_sampler.cc | |||
| mind_record_sampler.cc | |||
| ) | |||
| if(ENABLE_PYTHON) | |||
| @@ -0,0 +1,86 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h" | |||
| #include <algorithm> | |||
| #include <limits> | |||
| #include <memory> | |||
| #include "minddata/mindrecord/include/shard_reader.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| MindRecordSamplerRT::MindRecordSamplerRT(mindrecord::ShardReader *shard_reader, int64_t samples_per_tensor) | |||
| : SamplerRT(0, samples_per_tensor), next_id_(0), shard_reader_(shard_reader) {} | |||
| Status MindRecordSamplerRT::GetNextSample(TensorRow *out) { | |||
| if (next_id_ > num_samples_) { | |||
| RETURN_STATUS_UNEXPECTED("MindRecordSampler Internal Error"); | |||
| } else if (next_id_ == num_samples_) { | |||
| (*out) = TensorRow(TensorRow::kFlagEOE); | |||
| } else { | |||
| std::shared_ptr<Tensor> sampleIdsTensor; | |||
| int64_t last_id = std::min(samples_per_tensor_ + next_id_, num_samples_); | |||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIdsTensor, last_id - next_id_)); | |||
| auto id_ptr = sampleIdsTensor->begin<int64_t>(); | |||
| for (int64_t i = 0; i < (last_id - next_id_); i++) { | |||
| *(id_ptr + static_cast<ptrdiff_t>(i)) = (*sample_ids_)[i]; | |||
| } | |||
| next_id_ = last_id; | |||
| (*out) = {sampleIdsTensor}; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status MindRecordSamplerRT::InitSampler() { | |||
| sample_ids_ = shard_reader_->GetSampleIds(); | |||
| if (!sample_ids_) { | |||
| // Note, sample_ids_.empty() is okay and will just give no sample ids. | |||
| RETURN_STATUS_UNEXPECTED("ShardReader did not provide a valid sample ids vector via MindRecordSamplerRT"); | |||
| } | |||
 | // Usually, the number of samples is provided through the user interface. In our case, that information comes from the mindrecord files. | |||
 | // MindRecord has already created the sample ids at this point, so the number of samples is the size of the sampled id list. | |||
| num_samples_ = sample_ids_->size(); | |||
| return Status::OK(); | |||
| } | |||
| Status MindRecordSamplerRT::ResetSampler() { | |||
 | // Drive the shard reader to reshuffle its tasks, redoing the sampling for another epoch. | |||
| next_id_ = 0; | |||
| shard_reader_->ShuffleTask(); | |||
| return Status::OK(); | |||
| } | |||
| void MindRecordSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { | |||
| out << "\nSampler: MindRecordSampler"; | |||
| if (show_all) { | |||
| // Call the super class for displaying any common detailed info | |||
| SamplerRT::SamplerPrint(out, show_all); | |||
| // Then add our own info if any | |||
| } | |||
| } | |||
| Status MindRecordSamplerRT::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "MindRecordSampler"; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
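For orientation, here is a hypothetical driver of this sampler under the usual `SamplerRT` protocol: `InitSampler` once the `ShardReader` has produced its sample ids, `GetNextSample` until an EOE row, then `ResetSampler` at the epoch boundary (which, for this sampler, also reshuffles the reader's tasks). The `eoe()` check, the `DrainOneEpoch` wrapper, and macro/header availability are assumptions about the surrounding engine, not code from this patch.

```cpp
#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"

namespace mindspore {
namespace dataset {
// Hypothetical driver: `reader` points at an already-opened mindrecord::ShardReader
// whose sample ids have been generated (e.g. after MindRecordOp::Init()).
Status DrainOneEpoch(mindrecord::ShardReader *reader) {
  MindRecordSamplerRT sampler(reader);
  RETURN_IF_NOT_OK(sampler.InitSampler());   // num_samples_ := reader->GetSampleIds()->size()

  TensorRow ids;
  RETURN_IF_NOT_OK(sampler.GetNextSample(&ids));
  while (!ids.eoe()) {                       // EOE row marks the end of the epoch
    // hand the tensor of sample ids to the op's io-block logic here
    RETURN_IF_NOT_OK(sampler.GetNextSample(&ids));
  }
  RETURN_IF_NOT_OK(sampler.ResetSampler());  // next_id_ back to 0, ShardReader::ShuffleTask() re-run
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
```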
| @@ -0,0 +1,65 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_ | |||
| #include <limits> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/mindrecord/include/shard_reader.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class MindRecordSamplerRT : public SamplerRT { | |||
| public: | |||
| // Constructor | |||
 | // @param shard_reader - non-owning pointer to the shard reader that provides the sample ids | |||
 | // @param int64_t samples_per_tensor - number of sample ids to fetch per GetNextSample call | |||
| MindRecordSamplerRT(mindrecord::ShardReader *shard_reader, | |||
| int64_t samples_per_tensor = std::numeric_limits<int64_t>::max()); | |||
| // Destructor. | |||
| ~MindRecordSamplerRT() = default; | |||
| // Op calls this to get next set of sampleIds | |||
| // @param out - Tensor of sample ids to be returned to caller | |||
| // @return Status The status code returned | |||
| Status GetNextSample(TensorRow *out) override; | |||
| // meant to be called by base class or python | |||
| Status InitSampler() override; | |||
 | // Reset the sampler for the next epoch of sample ids | |||
| // @return Status The status code returned | |||
| Status ResetSampler() override; | |||
| void SamplerPrint(std::ostream &out, bool show_all) const override; | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| mindrecord::ShardReader *shard_reader_; // back pointer to the shard reader | |||
| const std::vector<int> *sample_ids_; // read-only back pointer into mind record sampler ids | |||
| int64_t next_id_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_ | |||
| @@ -23,6 +23,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| @@ -155,17 +156,19 @@ Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_)); | |||
| std::shared_ptr<MindRecordOp> mindrecord_op; | |||
| std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>(); | |||
| // If pass a string to MindData(), it will be treated as a pattern to search for matched files, | |||
| // else if pass a vector to MindData(), it will be treated as specified files to be read | |||
| if (search_for_pattern_) { | |||
| std::vector<std::string> dataset_file_vec_ = {dataset_file_}; | |||
| mindrecord_op = | |||
| std::make_shared<MindRecordOp>(num_workers_, dataset_file_vec_, search_for_pattern_, connector_que_size_, | |||
| columns_list_, operators_, num_padded_, padded_sample_, sample_bytes_); | |||
| mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, dataset_file_vec_, search_for_pattern_, | |||
| connector_que_size_, columns_list_, operators_, num_padded_, | |||
| padded_sample_, sample_bytes_, std::move(shard_reader)); | |||
| } else { | |||
| mindrecord_op = | |||
| std::make_shared<MindRecordOp>(num_workers_, dataset_files_, search_for_pattern_, connector_que_size_, | |||
| columns_list_, operators_, num_padded_, padded_sample_, sample_bytes_); | |||
| mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, dataset_files_, search_for_pattern_, | |||
| connector_que_size_, columns_list_, operators_, num_padded_, | |||
| padded_sample_, sample_bytes_, std::move(shard_reader)); | |||
| } | |||
| RETURN_IF_NOT_OK(mindrecord_op->Init()); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -46,7 +46,7 @@ class __attribute__((visibility("default"))) ShardCategory : public ShardOperato | |||
| bool GetReplacement() const { return replacement_; } | |||
| MSRStatus Execute(ShardTask &tasks) override; | |||
| MSRStatus Execute(ShardTaskList &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -39,15 +39,15 @@ class __attribute__((visibility("default"))) ShardDistributedSample : public Sha | |||
| ~ShardDistributedSample() override{}; | |||
| MSRStatus PreExecute(ShardTask &tasks) override; | |||
| MSRStatus PreExecute(ShardTaskList &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| private: | |||
| bool shuffle_; | |||
| int no_of_padded_samples_; | |||
| bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch | |||
| ShardTask task_; // maintain the input tasks in first epoch | |||
| bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch | |||
| ShardTaskList task_; // maintain the input tasks in first epoch | |||
| }; | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -18,7 +18,7 @@ | |||
| #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ | |||
| #include <memory> | |||
| #include "minddata/mindrecord/include/shard_task.h" | |||
| #include "minddata/mindrecord/include/shard_task_list.h" | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| @@ -26,7 +26,7 @@ class __attribute__((visibility("default"))) ShardOperator { | |||
| public: | |||
| virtual ~ShardOperator() = default; | |||
| MSRStatus operator()(ShardTask &tasks) { | |||
| MSRStatus operator()(ShardTaskList &tasks) { | |||
| if (SUCCESS != this->PreExecute(tasks)) { | |||
| return FAILED; | |||
| } | |||
| @@ -47,11 +47,11 @@ class __attribute__((visibility("default"))) ShardOperator { | |||
| virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; } | |||
| virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; } | |||
| virtual MSRStatus PreExecute(ShardTaskList &tasks) { return SUCCESS; } | |||
| virtual MSRStatus Execute(ShardTask &tasks) = 0; | |||
| virtual MSRStatus Execute(ShardTaskList &tasks) = 0; | |||
| virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; } | |||
| virtual MSRStatus SufExecute(ShardTaskList &tasks) { return SUCCESS; } | |||
| virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -38,7 +38,7 @@ class __attribute__((visibility("default"))) ShardPkSample : public ShardCategor | |||
| ~ShardPkSample() override{}; | |||
| MSRStatus SufExecute(ShardTask &tasks) override; | |||
| MSRStatus SufExecute(ShardTaskList &tasks) override; | |||
| int64_t GetNumSamples() const { return num_samples_; } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -210,6 +210,9 @@ class API_PUBLIC ShardReader { | |||
| /// \brief get the size of blob data | |||
| MSRStatus GetTotalBlobSize(int64_t *total_blob_size); | |||
| /// \brief get a read-only ptr to the sampled ids for this epoch | |||
| const std::vector<int> *GetSampleIds(); | |||
| protected: | |||
| /// \brief sqlite call back function | |||
| static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); | |||
| @@ -322,7 +325,7 @@ class API_PUBLIC ShardReader { | |||
| std::vector<std::string> selected_columns_; // columns which will be read | |||
| std::map<string, uint64_t> column_schema_id_; // column-schema map | |||
| std::vector<std::shared_ptr<ShardOperator>> operators_; // data operators, including shuffle, sample and category | |||
| ShardTask tasks_; // shard task | |||
| ShardTaskList tasks_; // shard task list | |||
| std::mutex shard_locker_; // locker of shard | |||
| // flags | |||
| @@ -339,7 +342,7 @@ class API_PUBLIC ShardReader { | |||
| std::mutex mtx_delivery_; // locker for delivery | |||
| std::condition_variable cv_delivery_; // conditional variable for delivery | |||
| std::condition_variable cv_iterator_; // conditional variable for iterator | |||
| std::atomic<int> task_id_; // task ID which is working | |||
| std::atomic<int> sample_id_position_; // index into the sample ids vector for the current sample id | |||
| std::atomic<int> deliver_id_; // delivery ID which is picked up by iterator | |||
| // map of delivery | |||
| std::unordered_map<int, std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>>> delivery_map_; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -40,11 +40,11 @@ class __attribute__((visibility("default"))) ShardSample : public ShardOperator | |||
| ~ShardSample() override{}; | |||
| MSRStatus Execute(ShardTask &tasks) override; | |||
| MSRStatus Execute(ShardTaskList &tasks) override; | |||
| MSRStatus UpdateTasks(ShardTask &tasks, int taking); | |||
| MSRStatus UpdateTasks(ShardTaskList &tasks, int taking); | |||
| MSRStatus SufExecute(ShardTask &tasks) override; | |||
| MSRStatus SufExecute(ShardTaskList &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -33,7 +33,7 @@ class __attribute__((visibility("default"))) ShardSequentialSample : public Shar | |||
| ~ShardSequentialSample() override{}; | |||
| MSRStatus Execute(ShardTask &tasks) override; | |||
| MSRStatus Execute(ShardTaskList &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -31,11 +31,14 @@ class __attribute__((visibility("default"))) ShardShuffle : public ShardOperator | |||
| ~ShardShuffle() override{}; | |||
| MSRStatus Execute(ShardTask &tasks) override; | |||
| MSRStatus Execute(ShardTaskList &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| private: | |||
| // Private helper function | |||
| MSRStatus CategoryShuffle(ShardTaskList &tasks); | |||
| uint32_t shuffle_seed_; | |||
| int64_t no_of_samples_; | |||
| bool replacement_; | |||
| @@ -1,109 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ | |||
| #include <algorithm> | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <tuple> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/mindrecord/include/common/shard_utils.h" | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| class __attribute__((visibility("default"))) ShardTask { | |||
| public: | |||
| ShardTask(); | |||
| ShardTask(const ShardTask &task); // copy construction | |||
| ShardTask &operator=(const ShardTask &task); // assignment operator | |||
| ~ShardTask() = default; | |||
| void MakePerm(); | |||
| inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset, | |||
| const json &label); | |||
| inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id, | |||
| const std::vector<uint64_t> &offset, const json &label); | |||
| inline void InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task); | |||
| inline void InsertTask(const uint32_t &i, | |||
| std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task); | |||
| void PopBack(); | |||
| uint32_t Size() const; | |||
| uint32_t SizeOfRows() const; | |||
| std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetTaskByID(size_t id); | |||
| std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask(); | |||
| static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements, | |||
| int64_t num_samples); | |||
| inline void ResizeTask(const uint32_t &size); | |||
| uint32_t categories; | |||
| // The total sample ids which used to shuffle operation. The ids like: [0, 1, 2, 3, ..., n-1, n] | |||
| std::vector<int> permutation_; | |||
| // The data struct is as below: | |||
| // 1. TaskType: kCommonTask / kPaddedTask | |||
| // 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load) | |||
| // 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields | |||
| std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_; | |||
| }; | |||
| inline void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset, | |||
| const json &label) { | |||
| MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id | |||
| << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; | |||
| task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label); | |||
| } | |||
| inline void ShardTask::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id, | |||
| const std::vector<uint64_t> &offset, const json &label) { | |||
| task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label}; | |||
| } | |||
| inline void ShardTask::InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) { | |||
| MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task)) | |||
| << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump() | |||
| << ", size of task_list_: " << task_list_.size() << "."; | |||
| task_list_.push_back(std::move(task)); | |||
| } | |||
| inline void ShardTask::InsertTask(const uint32_t &i, | |||
| std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) { | |||
| task_list_[i] = std::move(task); | |||
| } | |||
| inline void ShardTask::ResizeTask(const uint32_t &size) { task_list_.resize(size); } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ | |||
| @@ -0,0 +1,132 @@ | |||
| /** | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ | |||
| #include <algorithm> | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <tuple> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/mindrecord/include/common/shard_utils.h" | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| // The data struct is as below: | |||
| // 1. TaskType: kCommonTask / kPaddedTask | |||
| // 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load) | |||
| // 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields | |||
| using ShardTask = std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>; | |||
| class __attribute__((visibility("default"))) ShardTaskList { | |||
| public: | |||
| ShardTaskList(); | |||
| ShardTaskList(const ShardTaskList &task); // copy construction | |||
| ShardTaskList &operator=(const ShardTaskList &task); // assignment operator | |||
| ~ShardTaskList() = default; | |||
| void InitSampleIds(); | |||
| static void TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks); | |||
 | // Assigns a sample id from the source task list, selected by position id | |||
| inline void AssignTask(ShardTaskList &sourceTasks, size_t id); | |||
| inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset, | |||
| const json &label); | |||
| inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id, | |||
| const std::vector<uint64_t> &offset, const json &label); | |||
| inline void InsertTask(ShardTask task); | |||
| inline void InsertTask(const uint32_t &i, ShardTask task); | |||
| void MakePerm(); | |||
| inline void InsertSampleId(int id); | |||
| void PopBack(); | |||
| uint32_t Size() const; | |||
| uint32_t SizeOfRows() const; | |||
| ShardTask &GetTaskByID(size_t id); | |||
| ShardTask &GetRandomTask(); | |||
| int GetTaskSampleByID(size_t id); | |||
| int GetRandomTaskID(); | |||
| static ShardTaskList Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements, | |||
| int64_t num_samples); | |||
| inline void ResizeTask(const uint32_t &size); | |||
| uint32_t categories; | |||
| std::vector<int> permutation_; // A list of ints used for shuffling sample ids | |||
| std::vector<int> sample_ids_; // The list of actual ids that were sampled | |||
| std::vector<ShardTask> task_list_; // The full list of tasks | |||
| }; | |||
| inline void ShardTaskList::AssignTask(ShardTaskList &sourceTasks, size_t id) { | |||
 | // Insert the sample id from the source into our own list, indexing the source at position id. | |||
| // Important: The task list itself does not change. | |||
| int sample_id = sourceTasks.GetTaskSampleByID(id); | |||
| MS_LOG(DEBUG) << "Insert sample id (" << sample_id << ") into task list from source task position: " << id; | |||
| sample_ids_.push_back(sample_id); | |||
| } | |||
| inline void ShardTaskList::InsertTask(TaskType task_type, int shard_id, int group_id, | |||
| const std::vector<uint64_t> &offset, const json &label) { | |||
| MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id | |||
| << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; | |||
| task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label); | |||
| } | |||
| inline void ShardTaskList::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id, | |||
| const std::vector<uint64_t> &offset, const json &label) { | |||
| MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id | |||
| << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; | |||
| task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label}; | |||
| } | |||
| inline void ShardTaskList::InsertTask(ShardTask task) { | |||
| MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << std::get<0>(std::get<1>(task)) | |||
| << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump() | |||
| << ", size of task_list_: " << task_list_.size() << "."; | |||
| task_list_.push_back(std::move(task)); | |||
| } | |||
| inline void ShardTaskList::InsertTask(const uint32_t &i, ShardTask task) { task_list_[i] = std::move(task); } | |||
| inline void ShardTaskList::ResizeTask(const uint32_t &size) { task_list_.resize(size); } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ | |||
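The central idea of the new `ShardTaskList` is that the heavyweight `task_list_` is populated once and left alone, while sampling operators only manipulate `sample_ids_` (indices into that list) via `AssignTask`. A toy analogue of that pattern (illustrative names, not the real class):

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Toy analogue of ShardTaskList: heavy task records stay put, sampling only copies ids.
struct TaskPool {
  std::vector<std::string> tasks;  // analogue of task_list_ (never reordered by sampling)
  std::vector<int> sample_ids;     // analogue of sample_ids_ (what sampling produces)

  void InitSampleIds() {           // identity sampling: one id per task, in order
    if (sample_ids.empty()) {
      for (int i = 0; i < static_cast<int>(tasks.size()); ++i) sample_ids.push_back(i);
    }
  }
  void AssignTask(const TaskPool &src, size_t pos) {  // copy an id, not a whole task
    sample_ids.push_back(src.sample_ids[pos]);
  }
};

int main() {
  TaskPool pool;
  pool.tasks = {"row0", "row1", "row2", "row3"};
  pool.InitSampleIds();

  TaskPool sampled;
  sampled.tasks = pool.tasks;  // in the real class, TaskListSwap moves task_list_ across instead
  std::vector<size_t> picks = {3, 1};
  for (size_t pos : picks) sampled.AssignTask(pool, pos);  // keep only rows 3 and 1

  for (int id : sampled.sample_ids) std::printf("%s\n", sampled.tasks[id].c_str());  // row3, row1
  return 0;
}
```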
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -46,7 +46,7 @@ ShardReader::ShardReader() | |||
| num_padded_(0), | |||
| num_rows_(0), | |||
| total_blob_size_(0), | |||
| task_id_(0), | |||
| sample_id_position_(0), | |||
| deliver_id_(0), | |||
| lazy_load_(false), | |||
| shard_sample_count_() {} | |||
| @@ -1088,9 +1088,8 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator | |||
| i++; | |||
| } | |||
| } | |||
| // Generate task list, a task will create a batch | |||
| std::vector<ShardTask> categoryTasks(categories.size()); | |||
 | // Generate a vector of task lists. Each category has a list of tasks. | |||
| std::vector<ShardTaskList> categoryTasks(categories.size()); | |||
| for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { | |||
| int category_index = 0; | |||
| for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) { | |||
| @@ -1122,7 +1121,9 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator | |||
| } | |||
| } | |||
| } | |||
| tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples); | |||
| tasks_ = ShardTaskList::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples); | |||
| tasks_.InitSampleIds(); | |||
| if (SUCCESS != (*category_op)(tasks_)) { | |||
| return FAILED; | |||
| } | |||
| @@ -1246,6 +1247,10 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Created initial list of tasks. There are " << tasks_.Size() << " to start with before sampling."; | |||
| tasks_.InitSampleIds(); | |||
| for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) { | |||
| const auto &op = operators[operator_no]; | |||
| if (std::dynamic_pointer_cast<ShardCategory>(op)) continue; | |||
| @@ -1256,7 +1261,9 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u | |||
| if (tasks_.permutation_.empty()) tasks_.MakePerm(); | |||
| num_rows_ = tasks_.Size(); | |||
| MS_LOG(INFO) << "Total rows is " << num_rows_; | |||
| MS_LOG(INFO) << "Total rows is " << num_rows_ | |||
| << " and total amount sampled initially is: " << tasks_.sample_ids_.size(); | |||
| return SUCCESS; | |||
| } | |||
| @@ -1272,9 +1279,9 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ | |||
| uint32_t blob_start = 0; | |||
| uint32_t blob_end = 0; | |||
| json var_fields; | |||
| // Pick up task from task list | |||
| auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); | |||
| ShardTask task; | |||
| task = tasks_.GetTaskByID(task_id); | |||
| // check task type | |||
| auto task_type = std::get<0>(task); | |||
| @@ -1354,16 +1361,16 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) { | |||
| // Loop forever | |||
| for (;;) { | |||
| int task_id = 0; | |||
| int sample_id_pos = 0; | |||
| // Get next task ID | |||
| task_id = task_id_++; | |||
| sample_id_pos = sample_id_position_++; | |||
| // All tasks are done | |||
| if (task_id >= static_cast<int>(tasks_.Size())) { | |||
| if (sample_id_pos >= static_cast<int>(tasks_.sample_ids_.size())) { | |||
| return FAILED; | |||
| } | |||
| const auto &ret = ConsumerOneTask(task_id, consumer_id); | |||
| const auto &ret = ConsumerOneTask(tasks_.sample_ids_[sample_id_pos], consumer_id); | |||
| if (SUCCESS != ret.first) { | |||
| return FAILED; | |||
| } | |||
| @@ -1372,11 +1379,13 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) { | |||
| // otherwise, set batch data in map | |||
| { | |||
| std::unique_lock<std::mutex> lck(mtx_delivery_); | |||
| cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id <= deliver_id_ + kNumBatchInMap; }); | |||
| cv_delivery_.wait(lck, | |||
| [sample_id_pos, this] { return interrupt_ || sample_id_pos <= deliver_id_ + kNumBatchInMap; }); | |||
| if (interrupt_) { | |||
| return SUCCESS; | |||
| } | |||
| delivery_map_[task_id] = std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(std::move(batch)); | |||
| delivery_map_[sample_id_pos] = | |||
| std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(std::move(batch)); | |||
| } | |||
| cv_iterator_.notify_one(); | |||
| } | |||
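The worker-side change above swaps the per-task cursor for a cursor into the sampled id list: each consumer atomically claims the next position in `tasks_.sample_ids_` and feeds the stored sample id to `ConsumerOneTask`, so only sampled rows are ever read. A condensed, generic sketch of that claiming pattern (not the MindSpore function):

```cpp
#include <atomic>
#include <vector>

// Condensed worker-loop pattern: an atomic cursor walks the sampled id list,
// and each claimed position yields the actual sample id to process.
struct SampledWork {
  std::vector<int> sample_ids;   // analogue of tasks_.sample_ids_
  std::atomic<int> position{0};  // analogue of sample_id_position_
};

// Returns false once every sampled id has been claimed by some worker.
bool ProcessOne(SampledWork *work, void (*consume)(int sample_id)) {
  int pos = work->position++;  // claim the next position (atomic fetch-and-increment)
  if (pos >= static_cast<int>(work->sample_ids.size())) {
    return false;              // all sampled rows are done
  }
  consume(work->sample_ids[pos]);  // hand the real sample id downstream
  return true;
}
```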
| @@ -1386,7 +1395,7 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNext() { | |||
| if (interrupt_) { | |||
| return std::vector<std::tuple<std::vector<uint8_t>, json>>(); | |||
| } | |||
| if (deliver_id_ >= static_cast<int>(tasks_.Size())) { | |||
| if (deliver_id_ >= static_cast<int>(tasks_.sample_ids_.size())) { | |||
| return std::vector<std::tuple<std::vector<uint8_t>, json>>(); | |||
| } | |||
| @@ -1458,7 +1467,7 @@ std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> Sha | |||
| void ShardReader::Reset() { | |||
| { | |||
| std::lock_guard<std::mutex> lck(mtx_delivery_); | |||
| task_id_ = 0; | |||
| sample_id_position_ = 0; | |||
| deliver_id_ = 0; | |||
| } | |||
| cv_delivery_.notify_all(); | |||
| @@ -1486,5 +1495,10 @@ void ShardReader::ShuffleTask() { | |||
| if (tasks_.permutation_.empty()) tasks_.MakePerm(); | |||
| } | |||
| const std::vector<int> *ShardReader::GetSampleIds() { | |||
 | // Return a read-only pointer to the private sample id list. | |||
| return &(this->tasks_.sample_ids_); | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem | |||
| num_categories_(num_categories), | |||
| replacement_(replacement) {} | |||
| MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; } | |||
| MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; } | |||
| int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||
| if (dataset_size == 0) return dataset_size; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -55,7 +55,7 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_ | |||
| return 0; | |||
| } | |||
| MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) { | |||
| MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) { | |||
| auto total_no = tasks.Size(); | |||
| if (no_of_padded_samples_ > 0 && first_epoch_) { | |||
| if (total_no % denominator_ != 0) { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -37,7 +37,7 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem | |||
| shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement | |||
| } | |||
| MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) { | |||
| MSRStatus ShardPkSample::SufExecute(ShardTaskList &tasks) { | |||
| if (shuffle_ == true) { | |||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | |||
| return FAILED; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -80,21 +80,21 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||
| return 0; | |||
| } | |||
| MSRStatus ShardSample::UpdateTasks(ShardTask &tasks, int taking) { | |||
| MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { | |||
| if (tasks.permutation_.empty()) { | |||
| ShardTask new_tasks; | |||
| int total_no = static_cast<int>(tasks.Size()); | |||
| ShardTaskList new_tasks; | |||
| int total_no = static_cast<int>(tasks.sample_ids_.size()); | |||
| if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) { | |||
| for (int i = 0; i < indices_.size(); ++i) { | |||
| int index = ((indices_[i] % total_no) + total_no) % total_no; | |||
| new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python | |||
| new_tasks.AssignTask(tasks, index); // different mod result between c and python | |||
| } | |||
| } else { | |||
| int count = 0; | |||
| if (nums_per_shard_.empty()) { | |||
| for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | |||
| if (no_of_samples_ != 0 && count == no_of_samples_) break; | |||
| new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start | |||
| new_tasks.AssignTask(tasks, i % total_no); // rounding up. if overflow, go back to start | |||
| count++; | |||
| } | |||
| } else { | |||
| @@ -102,33 +102,33 @@ MSRStatus ShardSample::UpdateTasks(ShardTask &tasks, int taking) { | |||
| size_t i = partition_id_ - 1 >= 0 ? nums_per_shard_[partition_id_ - 1] : 0; | |||
| for (; i < nums_per_shard_[partition_id_]; i++) { | |||
| if (no_of_samples_ != 0 && count == no_of_samples_) break; | |||
| new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); | |||
| new_tasks.AssignTask(tasks, i % total_no); | |||
| count++; | |||
| } | |||
| } | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| } else { | |||
| ShardTask new_tasks; | |||
| if (taking > static_cast<int>(tasks.permutation_.size())) { | |||
| ShardTaskList new_tasks; | |||
| if (taking > static_cast<int>(tasks.sample_ids_.size())) { | |||
| return FAILED; | |||
| } | |||
| int total_no = static_cast<int>(tasks.permutation_.size()); | |||
| int count = 0; | |||
| for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | |||
| if (no_of_samples_ != 0 && count == no_of_samples_) break; | |||
| new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); | |||
| new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]); | |||
| count++; | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
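The retained `((indices_[i] % total_no) + total_no) % total_no` expression is the usual fix behind the "different mod result between c and python" note: C++ `%` keeps the sign of the dividend, so negative subset-sampler indices coming from the Python front end need an extra normalization step. A standalone demonstration:

```cpp
#include <cassert>

// Python-style modulo for possibly-negative indices: in C++, -1 % 5 is -1,
// while Python's -1 % 5 is 4. The double-mod expression bridges the gap.
int PyMod(int index, int total) { return ((index % total) + total) % total; }

int main() {
  assert(-1 % 5 == -1);       // plain C++ remainder keeps the dividend's sign
  assert(PyMod(-1, 5) == 4);  // normalized result, matching Python semantics
  assert(PyMod(7, 5) == 2);   // non-negative inputs are unaffected by the extra step
  return 0;
}
```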
| MSRStatus ShardSample::Execute(ShardTask &tasks) { | |||
| MSRStatus ShardSample::Execute(ShardTaskList &tasks) { | |||
| if (offset_ != -1) { | |||
| int64_t old_v = 0; | |||
| int num_rows_ = static_cast<int>(tasks.Size()); | |||
| int num_rows_ = static_cast<int>(tasks.sample_ids_.size()); | |||
| for (int x = 0; x < denominator_; x++) { | |||
| int samples_per_buffer_ = (num_rows_ + offset_) / denominator_; | |||
| int remainder = (num_rows_ + offset_) % denominator_; | |||
| @@ -140,8 +140,7 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) { | |||
| } | |||
| } | |||
| int no_of_categories = static_cast<int>(tasks.categories); | |||
| int total_no = static_cast<int>(tasks.Size()); // make sure task_size | |||
| int total_no = static_cast<int>(tasks.sample_ids_.size()); | |||
| int taking = 0; | |||
| if (sampler_type_ == kCustomTopNSampler) { // non sharding case constructor #1 | |||
| no_of_samples_ = std::min(no_of_samples_, total_no); | |||
| @@ -167,7 +166,7 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) { | |||
| return UpdateTasks(tasks, taking); | |||
| } | |||
| MSRStatus ShardSample::SufExecute(ShardTask &tasks) { | |||
| MSRStatus ShardSample::SufExecute(ShardTaskList &tasks) { | |||
| if (sampler_type_ == kSubsetRandomSampler) { | |||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | |||
| return FAILED; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -38,9 +38,9 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_c | |||
| return std::min(static_cast<int64_t>(no_of_samples_), dataset_size); | |||
| } | |||
| MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) { | |||
| int64_t total_no = static_cast<int64_t>(tasks.Size()); | |||
| MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) { | |||
| int64_t taking; | |||
| int64_t total_no = static_cast<int64_t>(tasks.sample_ids_.size()); | |||
| if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { | |||
| taking = total_no; | |||
| } else if (per_ > kEpsilon && per_ <= 1.0f) { | |||
| @@ -50,22 +50,22 @@ MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) { | |||
| } | |||
| if (tasks.permutation_.empty()) { | |||
| ShardTask new_tasks; | |||
| ShardTaskList new_tasks; | |||
| total_no = static_cast<int64_t>(tasks.Size()); | |||
| for (size_t i = offset_; i < taking + offset_; ++i) { | |||
| new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); | |||
| new_tasks.AssignTask(tasks, i % total_no); | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| } else { // shuffled | |||
| ShardTask new_tasks; | |||
| ShardTaskList new_tasks; | |||
| if (taking > static_cast<int64_t>(tasks.permutation_.size())) { | |||
| return FAILED; | |||
| } | |||
| total_no = static_cast<int64_t>(tasks.permutation_.size()); | |||
| for (size_t i = offset_; i < taking + offset_; ++i) { | |||
| new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); | |||
| new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]); | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -42,7 +42,31 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||
| return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_); | |||
| } | |||
| MSRStatus ShardShuffle::Execute(ShardTask &tasks) { | |||
| MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { | |||
| uint32_t individual_size; | |||
| individual_size = tasks.sample_ids_.size() / tasks.categories; | |||
| std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size)); | |||
| for (uint32_t i = 0; i < tasks.categories; i++) { | |||
| for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j); | |||
| std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); | |||
| } | |||
 | tasks.permutation_.clear(); // TODO(Jamie): replace this with setting a flag to false, or something similar | |||
| for (uint32_t j = 0; j < individual_size; j++) { | |||
| for (uint32_t i = 0; i < tasks.categories; i++) { | |||
| tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i)); | |||
| } | |||
| } | |||
| ShardTaskList new_tasks; | |||
| for (size_t i = 0; i < individual_size; ++i) { | |||
| new_tasks.AssignTask(tasks, tasks.permutation_[i]); | |||
| } | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| return SUCCESS; | |||
| } | |||
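The new `CategoryShuffle` helper shuffles each category's positions independently and then interleaves them, so shuffle units look like (a1, b1, c1), (a2, b2, c2), ...; the entry `new_permutations[i][j] * categories + i` maps category i's j-th pick back to a position in that interleaved layout. A tiny standalone example of the index arithmetic, assuming a category-interleaved layout with 2 categories of 3 samples each (`std::mt19937` standing in for the engine seeded with `shuffle_seed_` in the patch):

```cpp
#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>

// Interleaved category permutation for categories = 2, individual_size = 3,
// with positions laid out as a0, b0, a1, b1, a2, b2.
int main() {
  const int categories = 2, individual_size = 3;
  std::vector<std::vector<int>> per_category(categories, std::vector<int>(individual_size));
  std::mt19937 rng(0);  // fixed seed, standing in for shuffle_seed_
  for (int i = 0; i < categories; ++i) {
    for (int j = 0; j < individual_size; ++j) per_category[i][j] = j;
    std::shuffle(per_category[i].begin(), per_category[i].end(), rng);
  }
  std::vector<int> permutation;
  for (int j = 0; j < individual_size; ++j) {
    for (int i = 0; i < categories; ++i) {
      permutation.push_back(per_category[i][j] * categories + i);  // back to interleaved position
    }
  }
  for (int p : permutation) std::printf("%d ", p);  // a permutation of positions 0..5
  std::printf("\n");
  return 0;
}
```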
| MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { | |||
| if (reshuffle_each_epoch_) shuffle_seed_++; | |||
| if (tasks.categories < 1) { | |||
| return FAILED; | |||
| @@ -52,43 +76,31 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) { | |||
| tasks.MakePerm(); | |||
| } | |||
| if (replacement_ == true) { | |||
| ShardTask new_tasks; | |||
| if (no_of_samples_ == 0) { | |||
| no_of_samples_ = static_cast<int>(tasks.Size()); | |||
| } | |||
| ShardTaskList new_tasks; | |||
| if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size()); | |||
| if (no_of_samples_ <= 0) { | |||
| MS_LOG(ERROR) << "no_of_samples need to be positive."; | |||
| return FAILED; | |||
| } | |||
| new_tasks.task_list_.reserve(no_of_samples_); | |||
| for (uint32_t i = 0; i < no_of_samples_; ++i) { | |||
| new_tasks.InsertTask(tasks.GetRandomTask()); | |||
| new_tasks.AssignTask(tasks, tasks.GetRandomTaskID()); | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| } else { | |||
| std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_)); | |||
| auto total_no = static_cast<int64_t>(tasks.Size()); | |||
| if (no_of_samples_ > 0 && no_of_samples_ < total_no) { | |||
| ShardTask new_tasks; | |||
| for (size_t i = 0; i < no_of_samples_; ++i) { | |||
| new_tasks.InsertTask(tasks.GetTaskByID(i)); | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| ShardTaskList new_tasks; | |||
| size_t samples_to_assign = | |||
| (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size(); | |||
| for (size_t i = 0; i < samples_to_assign; ++i) { | |||
| new_tasks.AssignTask(tasks, tasks.permutation_[i]); | |||
| } | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| } | |||
| } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) | |||
| uint32_t individual_size = tasks.Size() / tasks.categories; | |||
| std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size)); | |||
| for (uint32_t i = 0; i < tasks.categories; i++) { | |||
| for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j); | |||
| std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); | |||
| } | |||
| tasks.permutation_.clear(); | |||
| for (uint32_t j = 0; j < individual_size; j++) { | |||
| for (uint32_t i = 0; i < tasks.categories; i++) { | |||
| tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i)); | |||
| } | |||
| } | |||
| return this->CategoryShuffle(tasks); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| #include "minddata/dataset/util/random.h" | |||
| #include "minddata/mindrecord/include/shard_task.h" | |||
| #include "minddata/mindrecord/include/shard_task_list.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "minddata/mindrecord/include/common/shard_utils.h" | |||
| @@ -25,55 +25,88 @@ using mindspore::MsLogLevel::DEBUG; | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| ShardTask::ShardTask() : categories(1) {} | |||
| ShardTaskList::ShardTaskList() : categories(1) {} | |||
| ShardTask::ShardTask(const ShardTask &other) | |||
| : categories(other.categories), permutation_(other.permutation_), task_list_(other.task_list_) {} | |||
| ShardTaskList::ShardTaskList(const ShardTaskList &other) | |||
| : categories(other.categories), | |||
| permutation_(other.permutation_), | |||
| sample_ids_(other.sample_ids_), | |||
| task_list_(other.task_list_) {} | |||
| ShardTask &ShardTask::operator=(const ShardTask &other) { | |||
| ShardTask tmp(other); | |||
| ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) { | |||
| ShardTaskList tmp(other); | |||
| std::swap(categories, tmp.categories); | |||
| permutation_.swap(tmp.permutation_); | |||
| sample_ids_.swap(tmp.sample_ids_); | |||
| task_list_.swap(tmp.task_list_); | |||
| return *this; | |||
| } | |||
| void ShardTask::MakePerm() { | |||
| permutation_ = std::vector<int>(task_list_.size()); | |||
| for (uint32_t i = 0; i < task_list_.size(); i++) { | |||
| void ShardTaskList::InitSampleIds() { | |||
 | // No-op if sample ids already exist; do not clobber the previous list. | |||
| if (sample_ids_.empty()) { | |||
| sample_ids_ = std::vector<int>(task_list_.size()); | |||
| for (int i = 0; i < task_list_.size(); i++) sample_ids_[i] = i; | |||
| } | |||
| } | |||
| void ShardTaskList::MakePerm() { | |||
| size_t perm_size = sample_ids_.size(); | |||
| permutation_ = std::vector<int>(perm_size); | |||
| for (uint32_t i = 0; i < perm_size; i++) { | |||
| permutation_[i] = static_cast<int>(i); | |||
| } | |||
| } | |||
| void ShardTask::PopBack() { task_list_.pop_back(); } | |||
| // Swap the new_tasks with orig_tasks | |||
| void ShardTaskList::TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks) { | |||
| // When swapping, if the orig_tasks contains fields that need to be preserved after the swap, then swapping with a | |||
| // new_tasks that does not have those fields will result in clobbering/losing the data after the swap. | |||
| // The task_list_ should not be lost/clobbered. | |||
| new_tasks.task_list_ = std::move(orig_tasks.task_list_); | |||
| uint32_t ShardTask::Size() const { return static_cast<uint32_t>(task_list_.size()); } | |||
| // Now, it's safe to drive the swap. | |||
| std::swap(orig_tasks, new_tasks); | |||
| } | |||
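`TaskListSwap` replaces the old `std::swap(tasks, new_tasks)`: because sampling now builds a `ShardTaskList` that only carries sample ids, the full `task_list_` is moved from the original into the new list before the swap, leaving the caller's object with the preserved tasks plus the freshly assigned ids. A reduced sketch of the preserve-then-swap idea (illustrative struct, not the real class):

```cpp
#include <utility>
#include <vector>

// Reduced sketch: `heavy` must survive the swap even though the "new" object
// was built without it, so move it across first and then swap wholesale.
struct TaskListLike {
  std::vector<int> heavy;  // analogue of task_list_
  std::vector<int> ids;    // analogue of sample_ids_
};

void PreserveAndSwap(TaskListLike &orig, TaskListLike &fresh) {
  fresh.heavy = std::move(orig.heavy);  // carry the task list into the new object
  std::swap(orig, fresh);               // orig now holds {preserved heavy, new ids}
}
```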
| uint32_t ShardTask::SizeOfRows() const { | |||
| void ShardTaskList::PopBack() { task_list_.pop_back(); } | |||
| uint32_t ShardTaskList::Size() const { return static_cast<uint32_t>(task_list_.size()); } | |||
| uint32_t ShardTaskList::SizeOfRows() const { | |||
| if (task_list_.size() == 0) return static_cast<uint32_t>(0); | |||
| // 1 task is 1 page | |||
| auto sum_num_rows = [](int x, std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> y) { | |||
| return x + std::get<2>(y)[0]; | |||
| }; | |||
| auto sum_num_rows = [](int x, ShardTask y) { return x + std::get<2>(y)[0]; }; | |||
| uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows); | |||
| return nRows; | |||
| } | |||
| std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetTaskByID(size_t id) { | |||
| ShardTask &ShardTaskList::GetTaskByID(size_t id) { | |||
| MS_ASSERT(id < task_list_.size()); | |||
| return task_list_[id]; | |||
| } | |||
| std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetRandomTask() { | |||
| int ShardTaskList::GetTaskSampleByID(size_t id) { | |||
| MS_ASSERT(id < sample_ids_.size()); | |||
| return sample_ids_[id]; | |||
| } | |||
| int ShardTaskList::GetRandomTaskID() { | |||
| std::mt19937 gen = mindspore::dataset::GetRandomDevice(); | |||
| std::uniform_int_distribution<> dis(0, task_list_.size() - 1); | |||
| return dis(gen); | |||
| } | |||
| ShardTask &ShardTaskList::GetRandomTask() { | |||
| std::mt19937 gen = mindspore::dataset::GetRandomDevice(); | |||
| std::uniform_int_distribution<> dis(0, task_list_.size() - 1); | |||
| return task_list_[dis(gen)]; | |||
| } | |||
| ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements, | |||
| int64_t num_samples) { | |||
| ShardTask res; | |||
| ShardTaskList ShardTaskList::Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements, | |||
| int64_t num_samples) { | |||
| ShardTaskList res; | |||
| if (category_tasks.empty()) return res; | |||
| auto total_categories = category_tasks.size(); | |||
| res.categories = static_cast<uint32_t>(total_categories); | |||
| @@ -107,6 +140,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac | |||
| } | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| } // namespace mindrecord | |||