Browse Source

!13203 Add const to C++ Sampler IR functions

From: @mhmotallebi
Reviewed-by: @nsyca,@robingrosman
Signed-off-by: @robingrosman
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
654be6216e
4 changed files with 23 additions and 24 deletions
  1. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc
  2. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.h
  3. +10
    -11
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.cc
  4. +11
    -11
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h

+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc View File

@@ -74,7 +74,7 @@ std::shared_ptr<SamplerObj> CacheLookupNode::SamplerCopy() {
return std::static_pointer_cast<SamplerObj>(lookup_node_copy_);
}

Status CacheLookupNode::SamplerBuild(std::shared_ptr<SamplerRT> *out) {
Status CacheLookupNode::SamplerBuild(std::shared_ptr<SamplerRT> *const out) {
// Runtime cache lookup op should already been built, so we just return it here
auto lookup_op = std::dynamic_pointer_cast<CacheLookupOp>(lookup_op_);
*out = std::shared_ptr<SamplerRT>(lookup_op);


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.h View File

@@ -50,7 +50,7 @@ class CacheLookupNode : public DatasetNode, public SamplerObj {
/// \brief a base class override function to convert a SamplerObj class into a runtime sampler object
/// \param[out] out Shared pointer to the newly created Sampler
/// \return The Status code of the function. It returns OK status if sampler is created successfully.
Status SamplerBuild(std::shared_ptr<SamplerRT> *out) override;
Status SamplerBuild(std::shared_ptr<SamplerRT> *const out) override;

/// \brief a base class override function to copy a SamplerObj class
/// \return Shared pointers to the newly copied SamplerObj


+ 10
- 11
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.cc View File

@@ -42,7 +42,7 @@ namespace dataset {
// Constructor
SamplerObj::SamplerObj() {}

Status SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> *sampler) {
Status SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> *const sampler) {
for (auto child : children_) {
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(child->SamplerBuild(&sampler_rt));
@@ -133,7 +133,7 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDa
}
#endif

Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
Status DistributedSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args;
args["sampler_name"] = "DistributedSampler";
args["num_shards"] = num_shards_;
@@ -170,7 +170,7 @@ Status PKSamplerObj::ValidateParams() {
return Status::OK();
}

Status PKSamplerObj::to_json(nlohmann::json *out_json) {
Status PKSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args;
args["sampler_name"] = "PKSampler";
args["num_val"] = num_val_;
@@ -222,13 +222,12 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator

Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }

Status PreBuiltSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
Status PreBuiltSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *const sampler) {
Status s = BuildChildren(&sp_);
if (s.IsOk())
*sampler = sp_;
else
*sampler = nullptr;
// FIXME: what to do with sp_ if status is not OK?
return s;
}

@@ -253,7 +252,7 @@ std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
return sampler;
}

Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
Status PreBuiltSamplerObj::to_json(nlohmann::json *const out_json) {
RETURN_IF_NOT_OK(sp_->to_json(out_json));
return Status::OK();
}
@@ -270,7 +269,7 @@ Status RandomSamplerObj::ValidateParams() {
return Status::OK();
}

Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
Status RandomSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args;
args["sampler_name"] = "RandomSampler";
args["replacement"] = replacement_;
@@ -325,7 +324,7 @@ Status SequentialSamplerObj::ValidateParams() {
return Status::OK();
}

Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
Status SequentialSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args;
args["sampler_name"] = "SequentialSampler";
args["start_index"] = start_index_;
@@ -389,7 +388,7 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset
return mind_sampler;
}
#endif
Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
Status SubsetSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetSampler";
args["indices"] = indices_;
@@ -428,7 +427,7 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindD
}
#endif

Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
Status SubsetRandomSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetRandomSampler";
args["indices"] = indices_;
@@ -474,7 +473,7 @@ Status WeightedRandomSamplerObj::ValidateParams() {
return Status::OK();
}

Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
Status WeightedRandomSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args;
args["sampler_name"] = "WeightedRandomSampler";
args["weights"] = weights_;


+ 11
- 11
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h View File

@@ -65,7 +65,7 @@ class SamplerObj {
/// \return the Status code returned
Status AddChildSampler(std::shared_ptr<SamplerObj> child);

virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
virtual Status to_json(nlohmann::json *const out_json) { return Status::OK(); }

std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }

@@ -80,7 +80,7 @@ class SamplerObj {
/// \brief A function that calls build on the children of this sampler
/// \param[in] sampler The samplerRT object built from this sampler
/// \return the Status code returned
Status BuildChildren(std::shared_ptr<SamplerRT> *sampler);
Status BuildChildren(std::shared_ptr<SamplerRT> *const sampler);

std::vector<std::shared_ptr<SamplerObj>> children_;
};
@@ -111,7 +111,7 @@ class DistributedSamplerObj : public SamplerObj {
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;

Status ValidateParams() override;

@@ -152,7 +152,7 @@ class PKSamplerObj : public SamplerObj {
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;

Status ValidateParams() override;

@@ -171,7 +171,7 @@ class PreBuiltSamplerObj : public SamplerObj {

~PreBuiltSamplerObj() = default;

Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
Status SamplerBuild(std::shared_ptr<SamplerRT> *const sampler) override;

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
@@ -181,7 +181,7 @@ class PreBuiltSamplerObj : public SamplerObj {

Status ValidateParams() override;

Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;

private:
std::shared_ptr<SamplerRT> sp_;
@@ -213,7 +213,7 @@ class RandomSamplerObj : public SamplerObj {
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;

Status ValidateParams() override;

@@ -246,7 +246,7 @@ class SequentialSamplerObj : public SamplerObj {
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;

Status ValidateParams() override;

@@ -278,7 +278,7 @@ class SubsetSamplerObj : public SamplerObj {
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;

Status ValidateParams() override;

@@ -293,7 +293,7 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj {

~SubsetRandomSamplerObj() = default;

Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;

Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;

@@ -331,7 +331,7 @@ class WeightedRandomSamplerObj : public SamplerObj {
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;

Status ValidateParams() override;



Loading…
Cancel
Save