Browse Source

!12353 Move sampler IR code to engine/ir

From: @mhmotallebi
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
ced5575387
11 changed files with 852 additions and 792 deletions
  1. +2
    -0
      mindspore/ccsrc/minddata/dataset/CMakeLists.txt
  2. +1
    -1
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  3. +1
    -1
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/samplers/bindings.cc
  4. +1
    -475
      mindspore/ccsrc/minddata/dataset/api/samplers.cc
  5. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt
  6. +8
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/CMakeLists.txt
  7. +490
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.cc
  8. +344
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h
  9. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/serdes.h
  10. +2
    -315
      mindspore/ccsrc/minddata/dataset/include/samplers.h
  11. +1
    -0
      mindspore/lite/minddata/CMakeLists.txt

+ 2
- 0
mindspore/ccsrc/minddata/dataset/CMakeLists.txt View File

@@ -97,6 +97,7 @@ add_dependencies(text-ir-kernels core)
add_dependencies(cpp-API core)
add_dependencies(engine-ir-datasetops core)
add_dependencies(engine-ir-datasetops-source core)
add_dependencies(engine-ir-datasetops-source-samplers core)
add_dependencies(engine-ir-cache core)
add_dependencies(kernels-ir core)
add_dependencies(kernels-ir-data core)
@@ -135,6 +136,7 @@ set(submodules
$<TARGET_OBJECTS:cpp-API>
$<TARGET_OBJECTS:engine-ir-datasetops>
$<TARGET_OBJECTS:engine-ir-datasetops-source>
$<TARGET_OBJECTS:engine-ir-datasetops-source-samplers>
$<TARGET_OBJECTS:engine-ir-cache>
$<TARGET_OBJECTS:kernels-soft-dvpp-image>
$<TARGET_OBJECTS:soft-dvpp-utils>


+ 1
- 1
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -42,7 +42,7 @@
#endif

// Sampler headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"



+ 1
- 1
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/samplers/bindings.cc View File

@@ -23,7 +23,7 @@
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"

namespace mindspore {
namespace dataset {


+ 1
- 475
mindspore/ccsrc/minddata/dataset/api/samplers.cc View File

@@ -15,69 +15,11 @@
*/

#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"

#ifndef ENABLE_ANDROID
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_sequential_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
#include "minddata/dataset/util/random.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"

namespace mindspore {
namespace dataset {

#define RETURN_NULL_IF_ERROR(_s) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
MS_LOG(ERROR) << __rc; \
return nullptr; \
} \
} while (false)

// Constructor
SamplerObj::SamplerObj() {}

void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
for (auto child : children_) {
auto sampler_rt = child->SamplerBuild();
sampler->AddChild(sampler_rt);
}
}

Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) {
if (child == nullptr) {
return Status::OK();
}

// Only samplers can be added, not any other DatasetOp.
std::shared_ptr<SamplerObj> sampler = std::dynamic_pointer_cast<SamplerObj>(child);
if (!sampler) {
RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object.");
}

// Samplers can have at most 1 child.
if (!children_.empty()) {
RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child.");
}

children_.push_back(child);

return Status::OK();
}

/// Function to create a Distributed Sampler.
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
int64_t num_samples, uint32_t seed, int64_t offset,
@@ -152,421 +94,5 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<doub
return sampler;
}

/* ####################################### Derived Sampler classes ################################# */

// DistributedSampler
DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
uint32_t seed, int64_t offset, bool even_dist)
: num_shards_(num_shards),
shard_id_(shard_id),
shuffle_(shuffle),
num_samples_(num_samples),
seed_(seed),
offset_(offset),
even_dist_(even_dist) {
// Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
// 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}

Status DistributedSamplerObj::ValidateParams() {
if (num_shards_ <= 0) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: num_shards must be greater than 0, but got: " +
std::to_string(num_shards_));
}

if (shard_id_ < 0 || shard_id_ >= num_shards_) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: shard_id must be in range [0, " + std::to_string(num_shards_) +
"), but got: " + std::to_string(shard_id_));
}

if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}

if (offset_ > num_shards_) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: offset must be no more than num_shards(" +
std::to_string(num_shards_) + "), but got: " + std::to_string(offset_));
}

return Status::OK();
}

std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
BuildChildren(sampler);
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_,
num_samples_, offset_);
return mind_sampler;
}
#endif

Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "DistributedSampler";
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
args["offset"] = offset_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

// PKSampler
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
: num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}

Status PKSamplerObj::ValidateParams() {
if (num_val_ <= 0) {
RETURN_STATUS_UNEXPECTED("PKSampler: num_val must be greater than 0, but got: " + std::to_string(num_val_));
}

if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("PKSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}

Status PKSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "PKSampler";
args["num_val"] = num_val_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
BuildChildren(sampler);
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
if (shuffle_ == true) {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
GetSeed(), num_samples_);
} else {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
}

return mind_sampler;
}
#endif

// PreBuiltOperation
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}

#ifndef ENABLE_ANDROID
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
: sp_minddataset_(std::move(sampler)) {}
#endif

Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }

std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() {
BuildChildren(sp_);
return sp_;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
#endif

std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
#ifndef ENABLE_ANDROID
if (sp_minddataset_ != nullptr) {
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#endif
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
RETURN_IF_NOT_OK(sp_->to_json(out_json));
return Status::OK();
}

// RandomSampler
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch)
: replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {}

Status RandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("RandomSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}

Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "RandomSampler";
args["replacement"] = replacement_;
args["num_samples"] = num_samples_;
args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
BuildChildren(sampler);
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler =
std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);

return mind_sampler;
}
#endif

// SequentialSampler
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
: start_index_(start_index), num_samples_(num_samples) {}

Status SequentialSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("SequentialSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}

if (start_index_ < 0) {
RETURN_STATUS_UNEXPECTED("SequentialSampler: start_index_ must be greater than or equal to 0, but got: " +
std::to_string(start_index_));
}

return Status::OK();
}

Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SequentialSampler";
args["start_index"] = start_index_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
BuildChildren(sampler);
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_);

return mind_sampler;
}
#endif

// SubsetSampler
SubsetSamplerObj::SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: indices_(std::move(indices)), num_samples_(num_samples) {}

Status SubsetSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}

return Status::OK();
}

std::shared_ptr<SamplerRT> SubsetSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetSamplerRT>(num_samples_, indices_);
BuildChildren(sampler);
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_);

return mind_sampler;
}
#endif
Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

// SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: SubsetSamplerObj(std::move(indices), num_samples) {}

std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
BuildChildren(sampler);
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed());

return mind_sampler;
}
#endif

Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetRandomSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

// WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}

Status WeightedRandomSamplerObj::ValidateParams() {
if (weights_.empty()) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
}
int32_t zero_elem = 0;
for (int32_t i = 0; i < weights_.size(); ++i) {
if (weights_[i] < 0) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " +
std::to_string(weights_[i]));
}
if (weights_[i] == 0.0) {
zero_elem++;
}
}
if (zero_elem == weights_.size()) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
}
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}

Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "WeightedRandomSampler";
args["weights"] = weights_;
args["num_samples"] = num_samples_;
args["replacement"] = replacement_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
BuildChildren(sampler);
return sampler;
}

} // namespace dataset
} // namespace mindspore

+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt View File

@@ -1,5 +1,6 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_subdirectory(samplers)

set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
album_node.cc


+ 8
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/CMakeLists.txt View File

@@ -0,0 +1,8 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SRC_FILES
samplers_ir.cc
)

add_library(engine-ir-datasetops-source-samplers OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SRC_FILES})

+ 490
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.cc View File

@@ -0,0 +1,490 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"

#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"

#ifndef ENABLE_ANDROID
#include "minddata/dataset/util/random.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_sequential_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
#endif

namespace mindspore {
namespace dataset {

// Constructor
SamplerObj::SamplerObj() {}

void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
for (auto child : children_) {
auto sampler_rt = child->SamplerBuild();
sampler->AddChild(sampler_rt);
}
}

Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) {
if (child == nullptr) {
return Status::OK();
}

// Only samplers can be added, not any other DatasetOp.
std::shared_ptr<SamplerObj> sampler = std::dynamic_pointer_cast<SamplerObj>(child);
if (!sampler) {
RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object.");
}

// Samplers can have at most 1 child.
if (!children_.empty()) {
RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child.");
}

children_.push_back(child);

return Status::OK();
}

/* ####################################### Derived Sampler classes ################################# */

// DistributedSampler
DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
uint32_t seed, int64_t offset, bool even_dist)
: num_shards_(num_shards),
shard_id_(shard_id),
shuffle_(shuffle),
num_samples_(num_samples),
seed_(seed),
offset_(offset),
even_dist_(even_dist) {
// Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
// 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}

Status DistributedSamplerObj::ValidateParams() {
if (num_shards_ <= 0) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: num_shards must be greater than 0, but got: " +
std::to_string(num_shards_));
}

if (shard_id_ < 0 || shard_id_ >= num_shards_) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: shard_id must be in range [0, " + std::to_string(num_shards_) +
"), but got: " + std::to_string(shard_id_));
}

if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}

if (offset_ > num_shards_) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: offset must be no more than num_shards(" +
std::to_string(num_shards_) + "), but got: " + std::to_string(offset_));
}

return Status::OK();
}

std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
BuildChildren(sampler);
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_,
num_samples_, offset_);
return mind_sampler;
}
#endif

Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "DistributedSampler";
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
args["offset"] = offset_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

// PKSampler
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
: num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}

Status PKSamplerObj::ValidateParams() {
if (num_val_ <= 0) {
RETURN_STATUS_UNEXPECTED("PKSampler: num_val must be greater than 0, but got: " + std::to_string(num_val_));
}

if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("PKSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}

Status PKSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "PKSampler";
args["num_val"] = num_val_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
BuildChildren(sampler);
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
if (shuffle_ == true) {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
GetSeed(), num_samples_);
} else {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
}

return mind_sampler;
}
#endif

// PreBuiltOperation
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}

#ifndef ENABLE_ANDROID
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
: sp_minddataset_(std::move(sampler)) {}
#endif

Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }

std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() {
BuildChildren(sp_);
return sp_;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
#endif

std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
#ifndef ENABLE_ANDROID
if (sp_minddataset_ != nullptr) {
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#endif
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
RETURN_IF_NOT_OK(sp_->to_json(out_json));
return Status::OK();
}

// RandomSampler
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch)
: replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {}

Status RandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("RandomSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}

Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "RandomSampler";
args["replacement"] = replacement_;
args["num_samples"] = num_samples_;
args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
BuildChildren(sampler);
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler =
std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);

return mind_sampler;
}
#endif

// SequentialSampler
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
: start_index_(start_index), num_samples_(num_samples) {}

Status SequentialSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("SequentialSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}

if (start_index_ < 0) {
RETURN_STATUS_UNEXPECTED("SequentialSampler: start_index_ must be greater than or equal to 0, but got: " +
std::to_string(start_index_));
}

return Status::OK();
}

Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SequentialSampler";
args["start_index"] = start_index_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
BuildChildren(sampler);
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_);

return mind_sampler;
}
#endif

// SubsetSampler
SubsetSamplerObj::SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: indices_(std::move(indices)), num_samples_(num_samples) {}

Status SubsetSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}

return Status::OK();
}

std::shared_ptr<SamplerRT> SubsetSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetSamplerRT>(num_samples_, indices_);
BuildChildren(sampler);
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_);

return mind_sampler;
}
#endif
Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

// SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: SubsetSamplerObj(std::move(indices), num_samples) {}

std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
BuildChildren(sampler);
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed());

return mind_sampler;
}
#endif

Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetRandomSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

// WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}

Status WeightedRandomSamplerObj::ValidateParams() {
if (weights_.empty()) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
}
int32_t zero_elem = 0;
for (int32_t i = 0; i < weights_.size(); ++i) {
if (weights_[i] < 0) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " +
std::to_string(weights_[i]));
}
if (weights_[i] == 0.0) {
zero_elem++;
}
}
if (zero_elem == weights_.size()) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
}
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}

Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "WeightedRandomSampler";
args["weights"] = weights_;
args["num_samples"] = num_samples_;
args["replacement"] = replacement_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
BuildChildren(sampler);
return sampler;
}

} // namespace dataset
} // namespace mindspore

+ 344
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h View File

@@ -0,0 +1,344 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_

#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <nlohmann/json.hpp>

#include "include/api/status.h"
#ifndef ENABLE_ANDROID
#include "minddata/mindrecord/include/shard_operator.h"
#endif

namespace mindspore {
namespace dataset {

// Internal Sampler class forward declaration
class SamplerRT;

class SamplerObj {
public:
/// \brief Constructor
SamplerObj();

/// \brief Destructor
~SamplerObj() = default;

/// \brief Pure virtual function for derived class to implement parameters validation
/// \return The Status code of the function. It returns OK status if parameters are valid.
virtual Status ValidateParams() = 0;

/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0;

/// \brief Pure virtual function to copy a SamplerObj class
/// \return Shared pointers to the newly copied SamplerObj
virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0;

/// \brief Function for derived class to get the shard id of sampler
/// \return The shard id of the derived sampler
virtual int64_t ShardId() { return 0; }

/// \brief Adds a child to the sampler
/// \param[in] child The sampler to be added as child
/// \return the Status code returned
Status AddChildSampler(std::shared_ptr<SamplerObj> child);

virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }

std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }

#ifndef ENABLE_ANDROID
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
#endif

protected:
/// \brief A function that calls build on the children of this sampler
/// \param[in] sampler The samplerRT object built from this sampler
void BuildChildren(std::shared_ptr<SamplerRT> sampler);

std::vector<std::shared_ptr<SamplerObj>> children_;
};

/* ####################################### Derived Sampler classes ################################# */
class DistributedSamplerObj : public SamplerObj {
public:
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed,
int64_t offset, bool even_dist);

~DistributedSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
offset_, even_dist_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;

Status ValidateParams() override;

/// \brief Function to get the shard id of sampler
/// \return The shard id of sampler
int64_t ShardId() override { return shard_id_; }

private:
int64_t num_shards_;
int64_t shard_id_;
bool shuffle_;
int64_t num_samples_;
uint32_t seed_;
int64_t offset_;
bool even_dist_;
};

class PKSamplerObj : public SamplerObj {
public:
PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples);

~PKSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;

Status ValidateParams() override;

private:
int64_t num_val_;
bool shuffle_;
int64_t num_samples_;
};

class PreBuiltSamplerObj : public SamplerObj {
public:
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
#ifndef ENABLE_ANDROID
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
#endif

~PreBuiltSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

std::shared_ptr<SamplerObj> SamplerCopy() override;

Status ValidateParams() override;

Status to_json(nlohmann::json *out_json) override;

private:
std::shared_ptr<SamplerRT> sp_;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_;
#endif
};

class RandomSamplerObj : public SamplerObj {
public:
RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true);

~RandomSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;

Status ValidateParams() override;

private:
bool replacement_;
int64_t num_samples_;
bool reshuffle_each_epoch_;
};

class SequentialSamplerObj : public SamplerObj {
public:
SequentialSamplerObj(int64_t start_index, int64_t num_samples);

~SequentialSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;

Status ValidateParams() override;

private:
int64_t start_index_;
int64_t num_samples_;
};

class SubsetSamplerObj : public SamplerObj {
public:
SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples);

~SubsetSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;

Status ValidateParams() override;

protected:
const std::vector<int64_t> indices_;
int64_t num_samples_;
};

class SubsetRandomSamplerObj : public SubsetSamplerObj {
public:
SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples);

~SubsetRandomSamplerObj() = default;

Status to_json(nlohmann::json *out_json) override;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

private:
};

class WeightedRandomSamplerObj : public SamplerObj {
public:
explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);

~WeightedRandomSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;

Status ValidateParams() override;

private:
const std::vector<double> weights_;
int64_t num_samples_;
bool replacement_;
};

} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_

+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/serdes.h View File

@@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_SERDES_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_SERDES_H_

#include <fstream>
#include <memory>
#include <string>
#include <vector>


+ 2
- 315
mindspore/ccsrc/minddata/dataset/include/samplers.h View File

@@ -18,72 +18,14 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_

#include <memory>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

#include "include/api/status.h"
#ifndef ENABLE_ANDROID
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_reader.h"
#endif
// FIXME - This internal IR header will be removed when external API classes are provided
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"

namespace mindspore {
namespace dataset {

// Internal Sampler class forward declaration
class SamplerRT;

class SamplerObj {
public:
/// \brief Constructor
SamplerObj();

/// \brief Destructor
~SamplerObj() = default;

/// \brief Pure virtual function for derived class to implement parameters validation
/// \return The Status code of the function. It returns OK status if parameters are valid.
virtual Status ValidateParams() = 0;

/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0;

/// \brief Pure virtual function to copy a SamplerObj class
/// \return Shared pointers to the newly copied SamplerObj
virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0;

/// \brief Function for derived class to get the shard id of sampler
/// \return The shard id of the derived sampler
virtual int64_t ShardId() { return 0; }

/// \brief Adds a child to the sampler
/// \param[in] child The sampler to be added as child
/// \return the Status code returned
Status AddChildSampler(std::shared_ptr<SamplerObj> child);

virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }

std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }

#ifndef ENABLE_ANDROID
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
#endif

protected:
/// \brief A function that calls build on the children of this sampler
/// \param[in] sampler The samplerRT object built from this sampler
void BuildChildren(std::shared_ptr<SamplerRT> sampler);

std::vector<std::shared_ptr<SamplerObj>> children_;
};

class DistributedSamplerObj;
class PKSamplerObj;
class PreBuiltSamplerObj;
@@ -155,261 +97,6 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t>
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0,
bool replacement = true);

/* ####################################### Derived Sampler classes ################################# */
class DistributedSamplerObj : public SamplerObj {
public:
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed,
int64_t offset, bool even_dist);

~DistributedSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
offset_, even_dist_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;

Status ValidateParams() override;

/// \brief Function to get the shard id of sampler
/// \return The shard id of sampler
int64_t ShardId() override { return shard_id_; }

private:
int64_t num_shards_;
int64_t shard_id_;
bool shuffle_;
int64_t num_samples_;
uint32_t seed_;
int64_t offset_;
bool even_dist_;
};

class PKSamplerObj : public SamplerObj {
public:
PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples);

~PKSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;

Status ValidateParams() override;

private:
int64_t num_val_;
bool shuffle_;
int64_t num_samples_;
};

class PreBuiltSamplerObj : public SamplerObj {
public:
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
#ifndef ENABLE_ANDROID
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
#endif

~PreBuiltSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

std::shared_ptr<SamplerObj> SamplerCopy() override;

Status ValidateParams() override;

Status to_json(nlohmann::json *out_json) override;

private:
std::shared_ptr<SamplerRT> sp_;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_;
#endif
};

class RandomSamplerObj : public SamplerObj {
public:
RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true);

~RandomSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;

Status ValidateParams() override;

private:
bool replacement_;
int64_t num_samples_;
bool reshuffle_each_epoch_;
};

class SequentialSamplerObj : public SamplerObj {
public:
SequentialSamplerObj(int64_t start_index, int64_t num_samples);

~SequentialSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;

Status ValidateParams() override;

private:
int64_t start_index_;
int64_t num_samples_;
};

class SubsetSamplerObj : public SamplerObj {
public:
SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples);

~SubsetSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;

Status ValidateParams() override;

protected:
const std::vector<int64_t> indices_;
int64_t num_samples_;
};

class SubsetRandomSamplerObj : public SubsetSamplerObj {
public:
SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples);

~SubsetRandomSamplerObj() = default;

Status to_json(nlohmann::json *out_json) override;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

private:
};

class WeightedRandomSamplerObj : public SamplerObj {
public:
explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);

~WeightedRandomSamplerObj() = default;

std::shared_ptr<SamplerRT> SamplerBuild() override;

std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}

/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;

Status ValidateParams() override;

private:
const std::vector<double> weights_;
int64_t num_samples_;
bool replacement_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_

+ 1
- 0
mindspore/lite/minddata/CMakeLists.txt View File

@@ -136,6 +136,7 @@ if(BUILD_MINDDATA STREQUAL "full")
${MINDDATA_DIR}/engine/ir/datasetops/shuffle_node.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/album_node.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/mnist_node.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/samplers_ir.cc
${MINDDATA_DIR}/engine/datasetops/dataset_op.cc
${MINDDATA_DIR}/engine/datasetops/repeat_op.cc
${MINDDATA_DIR}/engine/datasetops/epoch_ctrl_op.cc


Loading…
Cancel
Save