Browse Source

!13203 Add const to C++ Sampler IR functions

From: @mhmotallebi
Reviewed-by: @nsyca,@robingrosman
Signed-off-by: @robingrosman
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
654be6216e
4 changed files with 23 additions and 24 deletions
  1. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc
  2. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.h
  3. +10
    -11
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.cc
  4. +11
    -11
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h

+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc View File

@@ -74,7 +74,7 @@ std::shared_ptr<SamplerObj> CacheLookupNode::SamplerCopy() {
return std::static_pointer_cast<SamplerObj>(lookup_node_copy_); return std::static_pointer_cast<SamplerObj>(lookup_node_copy_);
} }


Status CacheLookupNode::SamplerBuild(std::shared_ptr<SamplerRT> *out) {
Status CacheLookupNode::SamplerBuild(std::shared_ptr<SamplerRT> *const out) {
// Runtime cache lookup op should already been built, so we just return it here // Runtime cache lookup op should already been built, so we just return it here
auto lookup_op = std::dynamic_pointer_cast<CacheLookupOp>(lookup_op_); auto lookup_op = std::dynamic_pointer_cast<CacheLookupOp>(lookup_op_);
*out = std::shared_ptr<SamplerRT>(lookup_op); *out = std::shared_ptr<SamplerRT>(lookup_op);


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.h View File

@@ -50,7 +50,7 @@ class CacheLookupNode : public DatasetNode, public SamplerObj {
/// \brief a base class override function to convert a SamplerObj class into a runtime sampler object /// \brief a base class override function to convert a SamplerObj class into a runtime sampler object
/// \param[out] out Shared pointer to the newly created Sampler /// \param[out] out Shared pointer to the newly created Sampler
/// \return The Status code of the function. It returns OK status if sampler is created successfully. /// \return The Status code of the function. It returns OK status if sampler is created successfully.
Status SamplerBuild(std::shared_ptr<SamplerRT> *out) override;
Status SamplerBuild(std::shared_ptr<SamplerRT> *const out) override;


/// \brief a base class override function to copy a SamplerObj class /// \brief a base class override function to copy a SamplerObj class
/// \return Shared pointers to the newly copied SamplerObj /// \return Shared pointers to the newly copied SamplerObj


+ 10
- 11
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.cc View File

@@ -42,7 +42,7 @@ namespace dataset {
// Constructor // Constructor
SamplerObj::SamplerObj() {} SamplerObj::SamplerObj() {}


Status SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> *sampler) {
Status SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> *const sampler) {
for (auto child : children_) { for (auto child : children_) {
std::shared_ptr<SamplerRT> sampler_rt = nullptr; std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(child->SamplerBuild(&sampler_rt)); RETURN_IF_NOT_OK(child->SamplerBuild(&sampler_rt));
@@ -133,7 +133,7 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDa
} }
#endif #endif


Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
Status DistributedSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args; nlohmann::json args;
args["sampler_name"] = "DistributedSampler"; args["sampler_name"] = "DistributedSampler";
args["num_shards"] = num_shards_; args["num_shards"] = num_shards_;
@@ -170,7 +170,7 @@ Status PKSamplerObj::ValidateParams() {
return Status::OK(); return Status::OK();
} }


Status PKSamplerObj::to_json(nlohmann::json *out_json) {
Status PKSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args; nlohmann::json args;
args["sampler_name"] = "PKSampler"; args["sampler_name"] = "PKSampler";
args["num_val"] = num_val_; args["num_val"] = num_val_;
@@ -222,13 +222,12 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator


Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); } Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }


Status PreBuiltSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
Status PreBuiltSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *const sampler) {
Status s = BuildChildren(&sp_); Status s = BuildChildren(&sp_);
if (s.IsOk()) if (s.IsOk())
*sampler = sp_; *sampler = sp_;
else else
*sampler = nullptr; *sampler = nullptr;
// FIXME: what to do with sp_ if status is not OK?
return s; return s;
} }


@@ -253,7 +252,7 @@ std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
return sampler; return sampler;
} }


Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
Status PreBuiltSamplerObj::to_json(nlohmann::json *const out_json) {
RETURN_IF_NOT_OK(sp_->to_json(out_json)); RETURN_IF_NOT_OK(sp_->to_json(out_json));
return Status::OK(); return Status::OK();
} }
@@ -270,7 +269,7 @@ Status RandomSamplerObj::ValidateParams() {
return Status::OK(); return Status::OK();
} }


Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
Status RandomSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args; nlohmann::json args;
args["sampler_name"] = "RandomSampler"; args["sampler_name"] = "RandomSampler";
args["replacement"] = replacement_; args["replacement"] = replacement_;
@@ -325,7 +324,7 @@ Status SequentialSamplerObj::ValidateParams() {
return Status::OK(); return Status::OK();
} }


Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
Status SequentialSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args; nlohmann::json args;
args["sampler_name"] = "SequentialSampler"; args["sampler_name"] = "SequentialSampler";
args["start_index"] = start_index_; args["start_index"] = start_index_;
@@ -389,7 +388,7 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset
return mind_sampler; return mind_sampler;
} }
#endif #endif
Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
Status SubsetSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args; nlohmann::json args;
args["sampler_name"] = "SubsetSampler"; args["sampler_name"] = "SubsetSampler";
args["indices"] = indices_; args["indices"] = indices_;
@@ -428,7 +427,7 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindD
} }
#endif #endif


Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
Status SubsetRandomSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args; nlohmann::json args;
args["sampler_name"] = "SubsetRandomSampler"; args["sampler_name"] = "SubsetRandomSampler";
args["indices"] = indices_; args["indices"] = indices_;
@@ -474,7 +473,7 @@ Status WeightedRandomSamplerObj::ValidateParams() {
return Status::OK(); return Status::OK();
} }


Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
Status WeightedRandomSamplerObj::to_json(nlohmann::json *const out_json) {
nlohmann::json args; nlohmann::json args;
args["sampler_name"] = "WeightedRandomSampler"; args["sampler_name"] = "WeightedRandomSampler";
args["weights"] = weights_; args["weights"] = weights_;


+ 11
- 11
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h View File

@@ -65,7 +65,7 @@ class SamplerObj {
/// \return the Status code returned /// \return the Status code returned
Status AddChildSampler(std::shared_ptr<SamplerObj> child); Status AddChildSampler(std::shared_ptr<SamplerObj> child);


virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
virtual Status to_json(nlohmann::json *const out_json) { return Status::OK(); }


std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; } std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }


@@ -80,7 +80,7 @@ class SamplerObj {
/// \brief A function that calls build on the children of this sampler /// \brief A function that calls build on the children of this sampler
/// \param[in] sampler The samplerRT object built from this sampler /// \param[in] sampler The samplerRT object built from this sampler
/// \return the Status code returned /// \return the Status code returned
Status BuildChildren(std::shared_ptr<SamplerRT> *sampler);
Status BuildChildren(std::shared_ptr<SamplerRT> *const sampler);


std::vector<std::shared_ptr<SamplerObj>> children_; std::vector<std::shared_ptr<SamplerObj>> children_;
}; };
@@ -111,7 +111,7 @@ class DistributedSamplerObj : public SamplerObj {
/// \brief Get the arguments of node /// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes /// \param[out] out_json JSON string of all attributes
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;


Status ValidateParams() override; Status ValidateParams() override;


@@ -152,7 +152,7 @@ class PKSamplerObj : public SamplerObj {
/// \brief Get the arguments of node /// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes /// \param[out] out_json JSON string of all attributes
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;


Status ValidateParams() override; Status ValidateParams() override;


@@ -171,7 +171,7 @@ class PreBuiltSamplerObj : public SamplerObj {


~PreBuiltSamplerObj() = default; ~PreBuiltSamplerObj() = default;


Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
Status SamplerBuild(std::shared_ptr<SamplerRT> *const sampler) override;


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
@@ -181,7 +181,7 @@ class PreBuiltSamplerObj : public SamplerObj {


Status ValidateParams() override; Status ValidateParams() override;


Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;


private: private:
std::shared_ptr<SamplerRT> sp_; std::shared_ptr<SamplerRT> sp_;
@@ -213,7 +213,7 @@ class RandomSamplerObj : public SamplerObj {
/// \brief Get the arguments of node /// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes /// \param[out] out_json JSON string of all attributes
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;


Status ValidateParams() override; Status ValidateParams() override;


@@ -246,7 +246,7 @@ class SequentialSamplerObj : public SamplerObj {
/// \brief Get the arguments of node /// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes /// \param[out] out_json JSON string of all attributes
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;


Status ValidateParams() override; Status ValidateParams() override;


@@ -278,7 +278,7 @@ class SubsetSamplerObj : public SamplerObj {
/// \brief Get the arguments of node /// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes /// \param[out] out_json JSON string of all attributes
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;


Status ValidateParams() override; Status ValidateParams() override;


@@ -293,7 +293,7 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj {


~SubsetRandomSamplerObj() = default; ~SubsetRandomSamplerObj() = default;


Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;


Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override; Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;


@@ -331,7 +331,7 @@ class WeightedRandomSamplerObj : public SamplerObj {
/// \brief Get the arguments of node /// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes /// \param[out] out_json JSON string of all attributes
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status to_json(nlohmann::json *const out_json) override;


Status ValidateParams() override; Status ValidateParams() override;




Loading…
Cancel
Save