| @@ -29,7 +29,7 @@ std::shared_ptr<ConfigManager> _config = GlobalContext::config_manager(); | |||||
| // Function to set the seed to be used in any random generator | // Function to set the seed to be used in any random generator | ||||
| bool set_seed(int32_t seed) { | bool set_seed(int32_t seed) { | ||||
| if (seed < 0 || seed > UINT32_MAX) { | |||||
| if (seed < 0 || seed > INT32_MAX) { | |||||
| MS_LOG(ERROR) << "Seed given is not within the required range: " << seed; | MS_LOG(ERROR) << "Seed given is not within the required range: " << seed; | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -68,7 +68,7 @@ int32_t get_num_parallel_workers() { return _config->num_parallel_workers(); } | |||||
| // Function to set the default interval (in milliseconds) for monitor sampling | // Function to set the default interval (in milliseconds) for monitor sampling | ||||
| bool set_monitor_sampling_interval(int32_t interval) { | bool set_monitor_sampling_interval(int32_t interval) { | ||||
| if (interval <= 0 || interval > UINT32_MAX) { | |||||
| if (interval <= 0 || interval > INT32_MAX) { | |||||
| MS_LOG(ERROR) << "Interval given is not within the required range: " << interval; | MS_LOG(ERROR) << "Interval given is not within the required range: " << interval; | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -81,7 +81,7 @@ int32_t get_monitor_sampling_interval() { return _config->monitor_sampling_inter | |||||
| // Function to set the default timeout (in seconds) for DSWaitedCallback | // Function to set the default timeout (in seconds) for DSWaitedCallback | ||||
| bool set_callback_timeback(int32_t timeout) { | bool set_callback_timeback(int32_t timeout) { | ||||
| if (timeout <= 0 || timeout > UINT32_MAX) { | |||||
| if (timeout <= 0 || timeout > INT32_MAX) { | |||||
| MS_LOG(ERROR) << "Timeout given is not within the required range: " << timeout; | MS_LOG(ERROR) << "Timeout given is not within the required range: " << timeout; | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -539,7 +539,7 @@ ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) { | |||||
| } | } | ||||
| #endif | #endif | ||||
| int64_t Dataset::GetBatchSize() { | int64_t Dataset::GetBatchSize() { | ||||
| int64_t batch_size; | |||||
| int64_t batch_size = -1; | |||||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | ||||
| RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1); | RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1); | ||||
| RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1); | RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1); | ||||
| @@ -548,7 +548,7 @@ int64_t Dataset::GetBatchSize() { | |||||
| } | } | ||||
| int64_t Dataset::GetRepeatCount() { | int64_t Dataset::GetRepeatCount() { | ||||
| int64_t repeat_count; | |||||
| int64_t repeat_count = 0; | |||||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | ||||
| RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1); | RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1); | ||||
| RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), 0); | RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), 0); | ||||
| @@ -227,7 +227,9 @@ Status SaveToDisk::Save() { | |||||
| nlohmann::json row_raw_data; | nlohmann::json row_raw_data; | ||||
| std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data; | std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data; | ||||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); | RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); | ||||
| if (row.empty()) break; | |||||
| if (row.empty()) { | |||||
| break; | |||||
| } | |||||
| if (first_loop) { | if (first_loop) { | ||||
| nlohmann::json mr_json; | nlohmann::json mr_json; | ||||
| std::vector<std::string> index_fields; | std::vector<std::string> index_fields; | ||||
| @@ -249,7 +251,7 @@ Status SaveToDisk::Save() { | |||||
| raw_data.insert( | raw_data.insert( | ||||
| std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data})); | std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data})); | ||||
| std::vector<std::vector<uint8_t>> bin_data; | std::vector<std::vector<uint8_t>> bin_data; | ||||
| if (nullptr != output_bin_data) { | |||||
| if (output_bin_data != nullptr) { | |||||
| bin_data.emplace_back(*output_bin_data); | bin_data.emplace_back(*output_bin_data); | ||||
| } | } | ||||
| mr_writer->WriteRawData(raw_data, bin_data); | mr_writer->WriteRawData(raw_data, bin_data); | ||||
| @@ -60,7 +60,7 @@ Status IteratorBase::GetNextAsMap(TensorMap *out_map) { | |||||
| } | } | ||||
| // Populate the out map from the row and return it | // Populate the out map from the row and return it | ||||
| for (auto colMap : col_name_id_map_) { | |||||
| for (const auto colMap : col_name_id_map_) { | |||||
| (*out_map)[colMap.first] = std::move(curr_row[colMap.second]); | (*out_map)[colMap.first] = std::move(curr_row[colMap.second]); | ||||
| } | } | ||||
| @@ -197,7 +197,7 @@ Status DatasetIterator::GetOutputShapes(std::vector<TensorShape> *out_shapes) { | |||||
| if (device_queue_row_.empty()) { | if (device_queue_row_.empty()) { | ||||
| RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_)); | RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_)); | ||||
| } | } | ||||
| for (auto ts : device_queue_row_) { | |||||
| for (const auto ts : device_queue_row_) { | |||||
| out_shapes->push_back(ts->shape()); | out_shapes->push_back(ts->shape()); | ||||
| } | } | ||||
| @@ -211,7 +211,7 @@ Status DatasetIterator::GetOutputTypes(std::vector<DataType> *out_types) { | |||||
| if (device_queue_row_.empty()) { | if (device_queue_row_.empty()) { | ||||
| RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_)); | RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_)); | ||||
| } | } | ||||
| for (auto ts : device_queue_row_) { | |||||
| for (const auto ts : device_queue_row_) { | |||||
| out_types->push_back(ts->type()); | out_types->push_back(ts->type()); | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -81,7 +81,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// Constructor | /// Constructor | ||||
| /// \param op_connector_size - The size for the output connector of this operator. | /// \param op_connector_size - The size for the output connector of this operator. | ||||
| /// \param sampler - The sampler for the op | /// \param sampler - The sampler for the op | ||||
| explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler); | |||||
| DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler); | |||||
| /// Destructor | /// Destructor | ||||
| virtual ~DatasetOp() { tree_ = nullptr; } | virtual ~DatasetOp() { tree_ = nullptr; } | ||||
| @@ -26,7 +26,7 @@ namespace dataset { | |||||
| CpuMapJob::CpuMapJob() = default; | CpuMapJob::CpuMapJob() = default; | ||||
| // Constructor | // Constructor | ||||
| CpuMapJob::CpuMapJob(std::vector<std::shared_ptr<TensorOp>> operations) : MapJob(operations) {} | |||||
| CpuMapJob::CpuMapJob(std::vector<std::shared_ptr<TensorOp>> operations) : MapJob(std::move(operations)) {} | |||||
| // Destructor | // Destructor | ||||
| CpuMapJob::~CpuMapJob() = default; | CpuMapJob::~CpuMapJob() = default; | ||||
| @@ -19,9 +19,9 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/kernels/tensor_op.h" | |||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/core/tensor_row.h" | #include "minddata/dataset/core/tensor_row.h" | ||||
| #include "minddata/dataset/kernels/tensor_op.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -18,9 +18,9 @@ | |||||
| #include <cstring> | #include <cstring> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/core/config_manager.h" | |||||
| #include "minddata/dataset/callback/callback_param.h" | #include "minddata/dataset/callback/callback_param.h" | ||||
| #include "minddata/dataset/core/config_manager.h" | |||||
| #include "minddata/dataset/core/constants.h" | #include "minddata/dataset/core/constants.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | #include "minddata/dataset/engine/data_buffer.h" | ||||
| @@ -44,7 +44,7 @@ MapOp::Builder::Builder() { | |||||
| Status MapOp::Builder::sanityCheck() const { | Status MapOp::Builder::sanityCheck() const { | ||||
| if (build_tensor_funcs_.empty()) { | if (build_tensor_funcs_.empty()) { | ||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | ||||
| "Building a MapOp that has not provided any function/operation to apply"); | |||||
| "Building a MapOp without providing any function/operation to apply"); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -121,26 +121,13 @@ Status MapOp::GenerateWorkerJob(const std::unique_ptr<MapWorkerJob> *worker_job) | |||||
| // In the future, we will have heuristic or control from user to select target device | // In the future, we will have heuristic or control from user to select target device | ||||
| MapTargetDevice target_device = MapTargetDevice::kCpu; | MapTargetDevice target_device = MapTargetDevice::kCpu; | ||||
| switch (target_device) { | |||||
| case MapTargetDevice::kCpu: | |||||
| // If there is no existing map_job, we will create one. | |||||
| // map_job could be nullptr when we are at the first tensor op or when the target device of the prev op | |||||
| // is different with that of the current op. | |||||
| if (map_job == nullptr) { | |||||
| map_job = std::make_shared<CpuMapJob>(); | |||||
| } | |||||
| map_job->AddOperation(tfuncs_[i]); | |||||
| break; | |||||
| case MapTargetDevice::kGpu: | |||||
| break; | |||||
| case MapTargetDevice::kDvpp: | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| // If there is no existing map_job, we will create one. | |||||
| // map_job could be nullptr when we are at the first tensor op or when the target device of the prev op | |||||
| // is different with that of the current op. | |||||
| if (map_job == nullptr) { | |||||
| map_job = std::make_shared<CpuMapJob>(); | |||||
| } | } | ||||
| map_job->AddOperation(tfuncs_[i]); | |||||
| // Push map_job into worker_job if one of the two conditions is true: | // Push map_job into worker_job if one of the two conditions is true: | ||||
| // 1) It is the last tensor operation in tfuncs_ | // 1) It is the last tensor operation in tfuncs_ | ||||
| @@ -364,7 +351,7 @@ Status MapOp::ComputeColMap() { | |||||
| // Validating if each of the input_columns exists in the DataBuffer. | // Validating if each of the input_columns exists in the DataBuffer. | ||||
| Status MapOp::ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map) { | Status MapOp::ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map) { | ||||
| for (const auto &inCol : in_columns_) { | for (const auto &inCol : in_columns_) { | ||||
| bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false; | |||||
| bool found = col_name_id_map.find(inCol) != col_name_id_map.end(); | |||||
| if (!found) { | if (!found) { | ||||
| std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; | std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; | ||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| @@ -30,7 +30,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| AlbumOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr), builder_schema_file_("") { | |||||
| AlbumOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { | |||||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | ||||
| builder_num_workers_ = cfg->num_parallel_workers(); | builder_num_workers_ = cfg->num_parallel_workers(); | ||||
| builder_rows_per_buffer_ = cfg->rows_per_buffer(); | builder_rows_per_buffer_ = cfg->rows_per_buffer(); | ||||
| @@ -62,9 +62,8 @@ Status AlbumOp::Builder::Build(std::shared_ptr<AlbumOp> *ptr) { | |||||
| Status AlbumOp::Builder::SanityCheck() { | Status AlbumOp::Builder::SanityCheck() { | ||||
| Path dir(builder_dir_); | Path dir(builder_dir_); | ||||
| std::string err_msg; | std::string err_msg; | ||||
| err_msg += dir.IsDirectory() == false | |||||
| ? "Invalid parameter, Album path is invalid or not set, path: " + builder_dir_ + ".\n" | |||||
| : ""; | |||||
| err_msg += | |||||
| !dir.IsDirectory() ? "Invalid parameter, Album path is invalid or not set, path: " + builder_dir_ + ".\n" : ""; | |||||
| err_msg += builder_num_workers_ <= 0 ? "Invalid parameter, num_parallel_workers must be greater than 0, but got " + | err_msg += builder_num_workers_ <= 0 ? "Invalid parameter, num_parallel_workers must be greater than 0, but got " + | ||||
| std::to_string(builder_num_workers_) + ".\n" | std::to_string(builder_num_workers_) + ".\n" | ||||
| : ""; | : ""; | ||||
| @@ -97,8 +96,8 @@ bool StrComp(const std::string &a, const std::string &b) { | |||||
| // returns 1 if string "a" represent a numeric value less than string "b" | // returns 1 if string "a" represent a numeric value less than string "b" | ||||
| // the following will always return name, provided there is only one "." character in name | // the following will always return name, provided there is only one "." character in name | ||||
| // "." character is guaranteed to exist since the extension is checked befor this function call. | // "." character is guaranteed to exist since the extension is checked befor this function call. | ||||
| int64_t value_a = std::atoi(a.substr(1, a.find(".")).c_str()); | |||||
| int64_t value_b = std::atoi(b.substr(1, b.find(".")).c_str()); | |||||
| int64_t value_a = std::stoi(a.substr(1, a.find(".")).c_str()); | |||||
| int64_t value_b = std::stoi(b.substr(1, b.find(".")).c_str()); | |||||
| return value_a < value_b; | return value_a < value_b; | ||||
| } | } | ||||
| @@ -261,6 +260,7 @@ Status AlbumOp::LoadImageTensor(const std::string &image_file_path, uint32_t col | |||||
| RETURN_IF_NOT_OK(LoadEmptyTensor(col_num, row)); | RETURN_IF_NOT_OK(LoadEmptyTensor(col_num, row)); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| fs.close(); | |||||
| // Hack logic to replace png images with empty tensor | // Hack logic to replace png images with empty tensor | ||||
| Path file(image_file_path); | Path file(image_file_path); | ||||
| std::set<std::string> png_ext = {".png", ".PNG"}; | std::set<std::string> png_ext = {".png", ".PNG"}; | ||||
| @@ -387,7 +387,7 @@ Status AlbumOp::LoadIDTensor(const std::string &file, uint32_t col_num, TensorRo | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // hack to get the file name without extension, the 1 is to get rid of the backslash character | // hack to get the file name without extension, the 1 is to get rid of the backslash character | ||||
| int64_t image_id = std::atoi(file.substr(1, file.find(".")).c_str()); | |||||
| int64_t image_id = std::stoi(file.substr(1, file.find(".")).c_str()); | |||||
| TensorPtr id; | TensorPtr id; | ||||
| RETURN_IF_NOT_OK(Tensor::CreateScalar<int64_t>(image_id, &id)); | RETURN_IF_NOT_OK(Tensor::CreateScalar<int64_t>(image_id, &id)); | ||||
| MS_LOG(INFO) << "File ID " << image_id << "."; | MS_LOG(INFO) << "File ID " << image_id << "."; | ||||
| @@ -16,16 +16,16 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_ALBUM_OP_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_ALBUM_OP_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_ALBUM_OP_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_ALBUM_OP_H_ | ||||
| #include <algorithm> | |||||
| #include <deque> | #include <deque> | ||||
| #include <map> | |||||
| #include <memory> | #include <memory> | ||||
| #include <queue> | #include <queue> | ||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <map> | |||||
| #include <set> | #include <set> | ||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include <unordered_map> | |||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | #include "minddata/dataset/engine/data_buffer.h" | ||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| @@ -55,7 +55,7 @@ Status DistributedSamplerRT::InitSampler() { | |||||
| if (offset_ != -1 || !even_dist_) { | if (offset_ != -1 || !even_dist_) { | ||||
| if (offset_ == -1) offset_ = 0; | if (offset_ == -1) offset_ = 0; | ||||
| samples_per_buffer_ = (num_rows_ + offset_) / num_devices_; | samples_per_buffer_ = (num_rows_ + offset_) / num_devices_; | ||||
| int remainder = (num_rows_ + offset_) % num_devices_; | |||||
| int64_t remainder = (num_rows_ + offset_) % num_devices_; | |||||
| if (device_id_ < remainder) samples_per_buffer_++; | if (device_id_ < remainder) samples_per_buffer_++; | ||||
| if (device_id_ < offset_) samples_per_buffer_--; | if (device_id_ < offset_) samples_per_buffer_--; | ||||
| } else { | } else { | ||||
| @@ -63,7 +63,7 @@ Status DistributedSamplerRT::InitSampler() { | |||||
| samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices) | samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices) | ||||
| } | } | ||||
| samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_; | samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_; | ||||
| if (shuffle_ == true) { | |||||
| if (shuffle_) { | |||||
| shuffle_vec_.reserve(num_rows_); | shuffle_vec_.reserve(num_rows_); | ||||
| for (int64_t i = 0; i < num_rows_; i++) { | for (int64_t i = 0; i < num_rows_; i++) { | ||||
| shuffle_vec_.push_back(i); | shuffle_vec_.push_back(i); | ||||
| @@ -30,7 +30,7 @@ PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t | |||||
| Status PKSamplerRT::InitSampler() { | Status PKSamplerRT::InitSampler() { | ||||
| labels_.reserve(label_to_ids_.size()); | labels_.reserve(label_to_ids_.size()); | ||||
| for (const auto &pair : label_to_ids_) { | for (const auto &pair : label_to_ids_) { | ||||
| if (pair.second.empty() == false) { | |||||
| if (!pair.second.empty()) { | |||||
| labels_.push_back(pair.first); | labels_.push_back(pair.first); | ||||
| } | } | ||||
| } | } | ||||
| @@ -76,6 +76,7 @@ Status PKSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_; | int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_; | ||||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_)); | RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_)); | ||||
| auto id_ptr = sample_ids->begin<int64_t>(); | auto id_ptr = sample_ids->begin<int64_t>(); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(samples_per_class_ != 0, "samples cannot be zero."); | |||||
| while (next_id_ < last_id && id_ptr != sample_ids->end<int64_t>()) { | while (next_id_ < last_id && id_ptr != sample_ids->end<int64_t>()) { | ||||
| int64_t cls_id = next_id_++ / samples_per_class_; | int64_t cls_id = next_id_++ / samples_per_class_; | ||||
| const std::vector<int64_t> &samples = label_to_ids_[labels_[cls_id]]; | const std::vector<int64_t> &samples = label_to_ids_[labels_[cls_id]]; | ||||
| @@ -32,8 +32,8 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED | |||||
| // @param int64_t val | // @param int64_t val | ||||
| // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2 | // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2 | ||||
| // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | ||||
| explicit PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| // default destructor | // default destructor | ||||
| ~PKSamplerRT() = default; | ~PKSamplerRT() = default; | ||||
| @@ -28,8 +28,8 @@ RandomSamplerRT::RandomSamplerRT(int64_t num_samples, bool replacement, bool res | |||||
| seed_(GetSeed()), | seed_(GetSeed()), | ||||
| replacement_(replacement), | replacement_(replacement), | ||||
| next_id_(0), | next_id_(0), | ||||
| reshuffle_each_epoch_(reshuffle_each_epoch), | |||||
| dist(nullptr) {} | |||||
| dist(nullptr), | |||||
| reshuffle_each_epoch_(reshuffle_each_epoch) {} | |||||
| Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | ||||
| if (next_id_ > num_samples_) { | if (next_id_ > num_samples_) { | ||||
| @@ -81,7 +81,7 @@ Status RandomSamplerRT::InitSampler() { | |||||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | ||||
| rnd_.seed(seed_); | rnd_.seed(seed_); | ||||
| if (replacement_ == false) { | |||||
| if (!replacement_) { | |||||
| shuffled_ids_.reserve(num_rows_); | shuffled_ids_.reserve(num_rows_); | ||||
| for (int64_t i = 0; i < num_rows_; i++) { | for (int64_t i = 0; i < num_rows_; i++) { | ||||
| shuffled_ids_.push_back(i); | shuffled_ids_.push_back(i); | ||||
| @@ -104,7 +104,7 @@ Status RandomSamplerRT::ResetSampler() { | |||||
| rnd_.seed(seed_); | rnd_.seed(seed_); | ||||
| if (replacement_ == false && reshuffle_each_epoch_) { | |||||
| if (!replacement_ && reshuffle_each_epoch_) { | |||||
| std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_); | std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_); | ||||
| } | } | ||||
| @@ -31,8 +31,8 @@ class RandomSamplerRT : public SamplerRT { | |||||
| // @param bool replacement - put he id back / or not after a sample | // @param bool replacement - put he id back / or not after a sample | ||||
| // @param reshuffle_each_epoch - T/F to reshuffle after epoch | // @param reshuffle_each_epoch - T/F to reshuffle after epoch | ||||
| // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | ||||
| explicit RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| // Destructor. | // Destructor. | ||||
| ~RandomSamplerRT() = default; | ~RandomSamplerRT() = default; | ||||
| @@ -50,7 +50,7 @@ class RandomSamplerRT : public SamplerRT { | |||||
| // @return - The error code return | // @return - The error code return | ||||
| Status ResetSampler() override; | Status ResetSampler() override; | ||||
| virtual void Print(std::ostream &out, bool show_all) const; | |||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| private: | private: | ||||
| uint32_t seed_; | uint32_t seed_; | ||||
| @@ -27,7 +27,7 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { | |||||
| // Here, it is just a getter method to return the value. However, it is invalid if there is | // Here, it is just a getter method to return the value. However, it is invalid if there is | ||||
| // not a value set for this count, so generate a failure if that is the case. | // not a value set for this count, so generate a failure if that is the case. | ||||
| if (num == nullptr || num_rows_ == 0) { | if (num == nullptr || num_rows_ == 0) { | ||||
| RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed it's num rows yet."); | |||||
| RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed its num rows yet."); | |||||
| } | } | ||||
| (*num) = num_rows_; | (*num) = num_rows_; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -60,7 +60,7 @@ class SamplerRT { | |||||
| // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 | // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 | ||||
| // indicates that the sampler should produce the complete set of ids. | // indicates that the sampler should produce the complete set of ids. | ||||
| // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call | // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call | ||||
| explicit SamplerRT(int64_t num_samples, int64_t samples_per_buffer); | |||||
| SamplerRT(int64_t num_samples, int64_t samples_per_buffer); | |||||
| SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_buffer_) {} | SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_buffer_) {} | ||||
| @@ -21,7 +21,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| SequentialSamplerRT::SequentialSamplerRT(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) | SequentialSamplerRT::SequentialSamplerRT(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) | ||||
| : SamplerRT(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} | |||||
| : SamplerRT(num_samples, samples_per_buffer), current_id_(start_index), start_index_(start_index), id_count_(0) {} | |||||
| Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | ||||
| if (id_count_ > num_samples_) { | if (id_count_ > num_samples_) { | ||||
| @@ -30,8 +30,8 @@ class SequentialSamplerRT : public SamplerRT { | |||||
| // full amount of ids from the dataset | // full amount of ids from the dataset | ||||
| // @param start_index - The starting index value | // @param start_index - The starting index value | ||||
| // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | ||||
| explicit SequentialSamplerRT(int64_t num_samples, int64_t start_index, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| SequentialSamplerRT(int64_t num_samples, int64_t start_index, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| // Destructor. | // Destructor. | ||||
| ~SequentialSamplerRT() = default; | ~SequentialSamplerRT() = default; | ||||
| @@ -20,7 +20,6 @@ | |||||
| #include <random> | #include <random> | ||||
| #include <string> | #include <string> | ||||
| #include "minddata/dataset/core/config_manager.h" | |||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/util/random.h" | #include "minddata/dataset/util/random.h" | ||||
| @@ -32,8 +32,8 @@ class SubsetRandomSamplerRT : public SamplerRT { | |||||
| // @param indices List of indices from where we will randomly draw samples. | // @param indices List of indices from where we will randomly draw samples. | ||||
| // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). | // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). | ||||
| // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. | // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. | ||||
| explicit SubsetRandomSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices, | |||||
| std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| SubsetRandomSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices, | |||||
| std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| // Destructor. | // Destructor. | ||||
| ~SubsetRandomSamplerRT() = default; | ~SubsetRandomSamplerRT() = default; | ||||
| @@ -42,9 +42,10 @@ Status WeightedRandomSamplerRT::InitSampler() { | |||||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | if (num_samples_ == 0 || num_samples_ > num_rows_) { | ||||
| num_samples_ = num_rows_; | num_samples_ = num_rows_; | ||||
| } | } | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_, | |||||
| "Invalid parameter, num_samples & num_rows must be greater than 0, but got num_rows: " + | |||||
| std::to_string(num_rows_) + ", num_samples: " + std::to_string(num_samples_)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||||
| num_rows_ > 0 && num_samples_, | |||||
| "Invalid parameter, num_samples and num_rows must be greater than 0, but got num_rows: " + | |||||
| std::to_string(num_rows_) + ", num_samples: " + std::to_string(num_samples_)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, | CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, | ||||
| "Invalid parameter, samples_per_buffer must be greater than 0, but got " + | "Invalid parameter, samples_per_buffer must be greater than 0, but got " + | ||||
| std::to_string(samples_per_buffer_) + ".\n"); | std::to_string(samples_per_buffer_) + ".\n"); | ||||
| @@ -57,7 +58,7 @@ Status WeightedRandomSamplerRT::InitSampler() { | |||||
| } | } | ||||
| if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { | if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { | ||||
| RETURN_STATUS_UNEXPECTED( | RETURN_STATUS_UNEXPECTED( | ||||
| "Invalid parameter, without replacement, weights size must be greater than or equal to num_samples, " | |||||
| "Invalid parameter, without replacement, weight size must be greater than or equal to num_samples, " | |||||
| "but got weight size: " + | "but got weight size: " + | ||||
| std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_)); | std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_)); | ||||
| } | } | ||||
| @@ -122,7 +123,7 @@ Status WeightedRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_b | |||||
| if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { | if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { | ||||
| RETURN_STATUS_UNEXPECTED( | RETURN_STATUS_UNEXPECTED( | ||||
| "Invalid parameter, without replacement, weights size must be greater than or equal to num_samples, " | |||||
| "Invalid parameter, without replacement, weight size must be greater than or equal to num_samples, " | |||||
| "but got weight size: " + | "but got weight size: " + | ||||
| std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_)); | std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_)); | ||||
| } | } | ||||
| @@ -18,8 +18,8 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | #include "minddata/dataset/engine/datasetops/dataset_op.h" | ||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore::dataset { | namespace mindspore::dataset { | ||||