| @@ -2,4 +2,5 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_library(engine-ir-cache OBJECT | |||
| pre_built_dataset_cache.cc | |||
| dataset_cache_impl.cc) | |||
| dataset_cache_impl.cc | |||
| dataset_cache.cc) | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/cache/dataset_cache.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <optional> | |||
| #include <vector> | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" | |||
| #endif | |||
| namespace mindspore::dataset { | |||
| #ifndef ENABLE_ANDROID | |||
| Status DatasetCache::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetCache> *cache) { | |||
| if (json_obj.find("cache") != json_obj.end()) { | |||
| nlohmann::json json_cache = json_obj["cache"]; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_cache.find("session_id") != json_cache.end(), "Failed to find session_id"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_cache.find("cache_memory_size") != json_cache.end(), | |||
| "Failed to find cache_memory_size"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_cache.find("spill") != json_cache.end(), "Failed to find spill"); | |||
| session_id_type id = static_cast<session_id_type>(json_cache["session_id"]); | |||
| uint64_t mem_sz = json_cache["cache_memory_size"]; | |||
| bool spill = json_cache["spill"]; | |||
| std::optional<std::vector<char>> hostname_c = std::nullopt; | |||
| std::optional<int32_t> port = std::nullopt; | |||
| std::optional<int32_t> num_connections = std::nullopt; | |||
| std::optional<int32_t> prefetch_sz = std::nullopt; | |||
| if (json_cache.find("hostname") != json_cache.end()) { | |||
| std::optional<std::string> hostname = json_cache["hostname"]; | |||
| hostname_c = std::vector<char>(hostname->begin(), hostname->end()); | |||
| } | |||
| if (json_cache.find("port") != json_cache.end()) port = json_cache["port"]; | |||
| if (json_cache.find("num_connections") != json_cache.end()) num_connections = json_cache["num_connections"]; | |||
| if (json_cache.find("prefetch_size") != json_cache.end()) prefetch_sz = json_cache["prefetch_size"]; | |||
| *cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname_c, port, num_connections, prefetch_sz); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace mindspore::dataset | |||
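For orientation, here is a minimal sketch of the `cache` fragment that `DatasetCache::from_json` consumes, using only the single-header nlohmann/json library. The key names mirror the checks above; the values and the standalone `main` are illustrative assumptions, not part of this change.

```cpp
// Illustrative only: a "cache" record shaped like the one
// DatasetCache::from_json() expects, probed with plain nlohmann::json calls.
#include <cstdint>
#include <iostream>
#include "nlohmann/json.hpp"

int main() {
  nlohmann::json json_obj = nlohmann::json::parse(R"({
    "cache": {
      "session_id": 3547471430,
      "cache_memory_size": 0,
      "spill": false,
      "hostname": "127.0.0.1",
      "port": 50052,
      "num_connections": 12,
      "prefetch_size": 20
    }
  })");

  const nlohmann::json &json_cache = json_obj["cache"];
  // session_id, cache_memory_size and spill are mandatory (guarded above).
  uint64_t mem_sz = json_cache["cache_memory_size"];
  bool spill = json_cache["spill"];
  // hostname/port/num_connections/prefetch_size are optional, probed with find().
  if (json_cache.find("port") != json_cache.end()) {
    int32_t port = json_cache["port"];
    std::cout << "port=" << port << "\n";
  }
  std::cout << "mem_sz=" << mem_sz << " spill=" << spill << "\n";
  return 0;
}
```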
| @@ -35,6 +35,10 @@ class DatasetCache { | |||
| virtual Status CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size, | |||
| std::shared_ptr<DatasetOp> *ds) = 0; | |||
| virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); } | |||
| #ifndef ENABLE_ANDROID | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetCache> *cache); | |||
| #endif | |||
| }; | |||
| } // namespace mindspore::dataset | |||
| @@ -169,5 +169,19 @@ Status BatchNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status BatchNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("batch_size") != json_obj.end(), "Failed to find batch_size"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("drop_remainder") != json_obj.end(), "Failed to find drop_remainder"); | |||
| int32_t batch_size = json_obj["batch_size"]; | |||
| bool drop_remainder = json_obj["drop_remainder"]; | |||
| *result = std::make_shared<BatchNode>(ds, batch_size, drop_remainder); | |||
| (*result)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
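All of the per-operation `from_json` functions in this change share the `(json_obj, child, *result)` shape, which suggests a table-driven deserializer. The sketch below is a hypothetical illustration of such a dispatch, not the actual `Serdes` code (which is outside this diff); the `op_type` key, the `Node` stand-in type, and the `NodeFromJson` alias are invented for the example.

```cpp
// Hypothetical sketch of dispatching a serialized op record to a
// from_json-style builder; not the real Serdes implementation.
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include "nlohmann/json.hpp"

struct Node {  // stand-in for DatasetNode, illustration only
  std::string describe;
};

using NodeFromJson =
    std::function<bool(const nlohmann::json &, std::shared_ptr<Node>, std::shared_ptr<Node> *)>;

int main() {
  std::map<std::string, NodeFromJson> registry;
  registry["Batch"] = [](const nlohmann::json &j, std::shared_ptr<Node> child,
                         std::shared_ptr<Node> *out) {
    // Mirrors the key checks in BatchNode::from_json above.
    if (j.find("batch_size") == j.end() || j.find("drop_remainder") == j.end()) return false;
    int32_t batch_size = j["batch_size"];
    *out = std::make_shared<Node>(
        Node{"Batch(" + child->describe + ", " + std::to_string(batch_size) + ")"});
    return true;
  };

  nlohmann::json record = {{"op_type", "Batch"}, {"batch_size", 32}, {"drop_remainder", true}};
  std::string op_type = record["op_type"];
  auto child = std::make_shared<Node>(Node{"source"});
  std::shared_ptr<Node> result;
  if (registry.at(op_type)(record, child, &result)) {
    std::cout << result->describe << "\n";  // prints: Batch(source, 32)
  }
  return 0;
}
```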
| @@ -105,6 +105,14 @@ class BatchNode : public DatasetNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Function to read dataset operation from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[in] ds The input (child) dataset node to which this operation is applied | |||
| /// \param[out] result Deserialized dataset after the operation | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| private: | |||
| int32_t batch_size_; | |||
| bool drop_remainder_; | |||
| @@ -22,6 +22,9 @@ | |||
| #include <utility> | |||
| #include <vector> | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/datasetops/map_op/map_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/kernels/ir/tensor_operation.h" | |||
| @@ -154,7 +157,6 @@ Status MapNode::to_json(nlohmann::json *out_json) { | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| std::vector<nlohmann::json> ops; | |||
| std::vector<int32_t> cbs; | |||
| for (auto op : operations_) { | |||
| @@ -177,6 +179,26 @@ Status MapNode::to_json(nlohmann::json *out_json) { | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status MapNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("input_columns") != json_obj.end(), "Failed to find input_columns"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("output_columns") != json_obj.end(), "Failed to find output_columns"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("project_columns") != json_obj.end(), "Failed to find project_columns"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("operations") != json_obj.end(), "Failed to find operations"); | |||
| std::vector<std::string> input_columns = json_obj["input_columns"]; | |||
| std::vector<std::string> output_columns = json_obj["output_columns"]; | |||
| std::vector<std::string> project_columns = json_obj["project_columns"]; | |||
| std::vector<std::shared_ptr<TensorOperation>> operations; | |||
| RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(json_obj["operations"], &operations)); | |||
| *result = std::make_shared<MapNode>(ds, operations, input_columns, output_columns, project_columns); | |||
| (*result)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
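For reference, a sketch of a MapNode record carrying the keys required above. The per-operation fields inside "operations" depend on each tensor op's own `to_json` and are invented here, as is the standalone driver, so treat the exact key names inside the array as assumptions.

```cpp
// Illustrative only: a MapNode record with the keys checked above.
// Each entry of "operations" would be handed to Serdes::ConstructTensorOps.
#include <iostream>
#include <string>
#include <vector>
#include "nlohmann/json.hpp"

int main() {
  nlohmann::json map_args = nlohmann::json::parse(R"({
    "num_parallel_workers": 8,
    "input_columns": ["image"],
    "output_columns": ["image"],
    "project_columns": [],
    "operations": [
      {"tensor_op_name": "Decode", "tensor_op_params": {}}
    ]
  })");

  // Column lists convert straight into std::vector<std::string>.
  std::vector<std::string> input_columns = map_args["input_columns"];
  for (const auto &op : map_args["operations"]) {
    std::cout << op["tensor_op_name"].get<std::string>() << "\n";  // Decode
  }
  std::cout << "input columns: " << input_columns.size() << "\n";
  return 0;
}
```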
| // Gets the dataset size | |||
| Status MapNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| @@ -93,6 +93,16 @@ class MapNode : public DatasetNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read dataset operation from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[in] ds The input (child) dataset node to which this operation is applied | |||
| /// \param[out] result Deserialized dataset after the operation | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| #endif | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| @@ -66,5 +66,13 @@ Status ProjectNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status ProjectNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("columns") != json_obj.end(), "Failed to find columns"); | |||
| std::vector<std::string> columns = json_obj["columns"]; | |||
| *result = std::make_shared<ProjectNode>(ds, columns); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -63,6 +63,14 @@ class ProjectNode : public DatasetNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Function to read dataset operation from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[in] ds The input (child) dataset node to which this operation is applied | |||
| /// \param[out] result Deserialized dataset after the operation | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| private: | |||
| std::vector<std::string> columns_; | |||
| }; | |||
| @@ -72,5 +72,16 @@ Status RenameNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status RenameNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("input_columns") != json_obj.end(), "Failed to find input_columns"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("output_columns") != json_obj.end(), "Failed to find output_columns"); | |||
| std::vector<std::string> input_columns = json_obj["input_columns"]; | |||
| std::vector<std::string> output_columns = json_obj["output_columns"]; | |||
| *result = std::make_shared<RenameNode>(ds, input_columns, output_columns); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -65,6 +65,14 @@ class RenameNode : public DatasetNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Function to read dataset operation from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[in] ds The input (child) dataset node to which this operation is applied | |||
| /// \param[out] result Deserialized dataset after the operation | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| private: | |||
| std::vector<std::string> input_columns_; | |||
| std::vector<std::string> output_columns_; | |||
| @@ -104,5 +104,14 @@ Status RepeatNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status RepeatNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count"); | |||
| int32_t count = json_obj["count"]; | |||
| *result = std::make_shared<RepeatNode>(ds, count); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -123,6 +123,14 @@ class RepeatNode : public DatasetNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Function to read dataset operation from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[in] ds The input (child) dataset node to which this operation is applied | |||
| /// \param[out] result Deserialized dataset after the operation | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| protected: | |||
| std::shared_ptr<RepeatOp> op_; // keep its corresponding run-time op of EpochCtrlNode and RepeatNode | |||
| std::shared_ptr<RepeatNode> reset_ancestor_; // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass | |||
| @@ -66,9 +66,19 @@ Status ShuffleNode::ValidateParams() { | |||
| Status ShuffleNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["buffer_size"] = shuffle_size_; | |||
| args["reshuffle_each_epoch"] = reset_every_epoch_; | |||
| args["reset_each_epoch"] = reset_every_epoch_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status ShuffleNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("buffer_size") != json_obj.end(), "Failed to find buffer_size"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("reset_each_epoch") != json_obj.end(), "Failed to find reset_each_epoch"); | |||
| int32_t buffer_size = json_obj["buffer_size"]; | |||
| bool reset_every_epoch = json_obj["reset_each_epoch"]; | |||
| *result = std::make_shared<ShuffleNode>(ds, buffer_size, reset_every_epoch); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
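Two details worth noting from the ShuffleNode change: the serialized key appears to be renamed from `reshuffle_each_epoch` to `reset_each_epoch` (and `from_json` reads the new name), and the shuffle seed is not serialized at all, so a deserialized node presumably draws a fresh seed in its constructor. A tiny consistency check under those assumptions:

```cpp
// Illustrative check: the key written by ShuffleNode::to_json must be the
// one ShuffleNode::from_json looks for, otherwise deserialization fails.
#include <cassert>
#include "nlohmann/json.hpp"

int main() {
  nlohmann::json args;
  args["buffer_size"] = 10000;      // written by to_json
  args["reset_each_epoch"] = true;  // renamed key written by to_json
  assert(args.find("buffer_size") != args.end());
  assert(args.find("reset_each_epoch") != args.end());
  // No "shuffle_seed" key exists, so the seed is not round-tripped.
  assert(args.find("shuffle_seed") == args.end());
  return 0;
}
```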
| @@ -63,6 +63,14 @@ class ShuffleNode : public DatasetNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Function to read dataset operation from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[in] ds The input (child) dataset node to which this operation is applied | |||
| /// \param[out] result Deserialized dataset after the operation | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| private: | |||
| int32_t shuffle_size_; | |||
| uint32_t shuffle_seed_; | |||
| @@ -93,5 +93,13 @@ Status SkipNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status SkipNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count"); | |||
| int32_t count = json_obj["count"]; | |||
| *result = std::make_shared<SkipNode>(ds, count); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -88,6 +88,14 @@ class SkipNode : public DatasetNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Function to read dataset operation from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[in] ds The input (child) dataset node to which this operation is applied | |||
| /// \param[out] result Deserialized dataset after the operation | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| private: | |||
| int32_t skip_count_; | |||
| }; | |||
| @@ -25,6 +25,9 @@ | |||
| #include "debug/common.h" | |||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #endif | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -182,5 +185,28 @@ Status CelebANode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status CelebANode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extensions") != json_obj.end(), "Failed to find extension"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| std::string usage = json_obj["usage"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler)); | |||
| bool decode = json_obj["decode"]; | |||
| std::set<std::string> extension = json_obj["extensions"]; | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); | |||
| *ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extension, cache); | |||
| (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
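A sketch of a CelebA record with the keys required above; the sampler sub-record's field names and all values are illustrative assumptions. The one non-obvious conversion is `extensions`: a JSON array deserializes directly into the `std::set<std::string>` that the node constructor takes.

```cpp
// Illustrative only: a CelebANode record; note the array-to-std::set
// conversion for "extensions".
#include <iostream>
#include <set>
#include <string>
#include "nlohmann/json.hpp"

int main() {
  nlohmann::json args = nlohmann::json::parse(R"({
    "num_parallel_workers": 8,
    "dataset_dir": "/path/to/celeba",
    "usage": "all",
    "sampler": {"num_samples": 0, "start_index": 0},
    "decode": false,
    "extensions": [".jpg", ".png"]
  })");

  std::set<std::string> extensions = args["extensions"];
  for (const auto &ext : extensions) std::cout << ext << "\n";
  return 0;
}
```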
| @@ -82,6 +82,14 @@ class CelebANode : public MappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the dataset from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| #endif | |||
| /// \brief Sampler getter | |||
| /// \return SamplerObj of the current node | |||
| std::shared_ptr<SamplerObj> Sampler() override { return sampler_; } | |||
| @@ -22,6 +22,9 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #endif | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -117,5 +120,24 @@ Status Cifar100Node::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status Cifar100Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| std::string usage = json_obj["usage"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler)); | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); | |||
| *ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler, cache); | |||
| (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -78,6 +78,14 @@ class Cifar100Node : public MappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the dataset from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| #endif | |||
| /// \brief Sampler getter | |||
| /// \return SamplerObj of the current node | |||
| std::shared_ptr<SamplerObj> Sampler() override { return sampler_; } | |||
| @@ -22,6 +22,9 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #endif | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -118,5 +121,24 @@ Status Cifar10Node::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status Cifar10Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| std::string usage = json_obj["usage"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler)); | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); | |||
| *ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler, cache); | |||
| (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -78,6 +78,14 @@ class Cifar10Node : public MappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the dataset from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| #endif | |||
| /// \brief Sampler getter | |||
| /// \return SamplerObj of the current node | |||
| std::shared_ptr<SamplerObj> Sampler() override { return sampler_; } | |||
| @@ -249,6 +249,29 @@ Status CLUENode::to_json(nlohmann::json *out_json) { | |||
| return Status::OK(); | |||
| } | |||
| Status CLUENode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id"); | |||
| std::vector<std::string> dataset_files = json_obj["dataset_dir"]; | |||
| std::string task = json_obj["task"]; | |||
| std::string usage = json_obj["usage"]; | |||
| int64_t num_samples = json_obj["num_samples"]; | |||
| ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]); | |||
| int32_t num_shards = json_obj["num_shards"]; | |||
| int32_t shard_id = json_obj["shard_id"]; | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); | |||
| *ds = std::make_shared<CLUENode>(dataset_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache); | |||
| (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
| // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent | |||
| // class. CLUE by itself is a non-mappable dataset that does not support sampling. However, if a cache operator is | |||
| // injected at some other place higher in the tree, that cache can inherit this sampler from the leaf, providing | |||
| @@ -86,6 +86,12 @@ class CLUENode : public NonMappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Function to read the dataset from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| /// \brief CLUE by itself is a non-mappable dataset that does not support sampling. | |||
| /// However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| /// inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| @@ -22,6 +22,9 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/coco_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #endif | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -181,6 +184,7 @@ Status CocoNode::to_json(nlohmann::json *out_json) { | |||
| args["annotation_file"] = annotation_file_; | |||
| args["task"] = task_; | |||
| args["decode"] = decode_; | |||
| args["extra_metadata"] = extra_metadata_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| @@ -189,5 +193,30 @@ Status CocoNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status CocoNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extra_metadata") != json_obj.end(), "Failed to find extra_metadata"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| std::string annotation_file = json_obj["annotation_file"]; | |||
| std::string task = json_obj["task"]; | |||
| bool decode = json_obj["decode"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler)); | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); | |||
| bool extra_metadata = json_obj["extra_metadata"]; | |||
| *ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler, cache, extra_metadata); | |||
| (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -81,6 +81,14 @@ class CocoNode : public MappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the dataset from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| #endif | |||
| /// \brief Sampler getter | |||
| /// \return SamplerObj of the current node | |||
| std::shared_ptr<SamplerObj> Sampler() override { return sampler_; } | |||
| @@ -187,6 +187,32 @@ Status CSVNode::to_json(nlohmann::json *out_json) { | |||
| return Status::OK(); | |||
| } | |||
| Status CSVNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("field_delim") != json_obj.end(), "Failed to find field_delim"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("column_names") != json_obj.end(), "Failed to find column_names"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id"); | |||
| std::vector<std::string> dataset_files = json_obj["dataset_files"]; | |||
| std::string field_delim = json_obj["field_delim"]; | |||
| std::vector<std::shared_ptr<CsvBase>> column_defaults = {}; | |||
| std::vector<std::string> column_names = json_obj["column_names"]; | |||
| int64_t num_samples = json_obj["num_samples"]; | |||
| ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]); | |||
| int32_t num_shards = json_obj["num_shards"]; | |||
| int32_t shard_id = json_obj["shard_id"]; | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); | |||
| *ds = std::make_shared<CSVNode>(dataset_files, field_delim.c_str()[0], column_defaults, column_names, num_samples, | |||
| shuffle, num_shards, shard_id, cache); | |||
| (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
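Two design points visible above are worth calling out: `column_defaults` is not round-tripped (the deserialized CSVNode always receives an empty defaults vector), and `field_delim` travels as a string whose first character is handed back to the constructor. A minimal sketch of that second conversion, with illustrative values:

```cpp
// Illustrative only: field_delim is serialized as a one-character string and
// reduced back to a char, exactly as CSVNode::from_json does above.
#include <cassert>
#include <string>
#include "nlohmann/json.hpp"

int main() {
  nlohmann::json args = {{"field_delim", ","}};
  std::string field_delim = args["field_delim"];
  char delim = field_delim.c_str()[0];
  assert(delim == ',');
  return 0;
}
```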
| // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. | |||
| // CSV by itself is a non-mappable dataset that does not support sampling. | |||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| @@ -107,6 +107,12 @@ class CSVNode : public NonMappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Function to read the dataset from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| /// \brief CSV by itself is a non-mappable dataset that does not support sampling. | |||
| /// However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| /// inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| @@ -24,6 +24,9 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #endif | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -113,6 +116,7 @@ Status ImageFolderNode::to_json(nlohmann::json *out_json) { | |||
| args["sampler"] = sampler_args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_dir_; | |||
| args["recursive"] = recursive_; | |||
| args["decode"] = decode_; | |||
| args["extensions"] = exts_; | |||
| args["class_indexing"] = class_indexing_; | |||
| @@ -124,5 +128,36 @@ Status ImageFolderNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status ImageFolderNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("recursive") != json_obj.end(), "Failed to find recursive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extensions") != json_obj.end(), "Failed to find extension"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| bool decode = json_obj["decode"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler)); | |||
| bool recursive = json_obj["recursive"]; | |||
| std::set<std::string> extension = json_obj["extensions"]; | |||
| std::map<std::string, int32_t> class_indexing; | |||
| nlohmann::json class_map = json_obj["class_indexing"]; | |||
| for (const auto &class_map_child : class_map) { | |||
| std::string class_ = class_map_child[0]; | |||
| int32_t indexing = class_map_child[1]; | |||
| class_indexing.insert({class_, indexing}); | |||
| } | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); | |||
| *ds = std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extension, class_indexing, cache); | |||
| (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
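The loop above implies that `class_indexing` is serialized as an array of `[name, index]` pairs rather than as a JSON object, which presumably avoids relying on JSON object key ordering. A small sketch of reading that layout (values illustrative):

```cpp
// Illustrative only: class_indexing stored as [name, index] pairs, the
// layout walked by ImageFolderNode::from_json above.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include "nlohmann/json.hpp"

int main() {
  nlohmann::json class_map = nlohmann::json::parse(R"([["cat", 0], ["dog", 1]])");
  std::map<std::string, int32_t> class_indexing;
  for (const auto &class_map_child : class_map) {
    std::string name = class_map_child[0];
    int32_t index = class_map_child[1];
    class_indexing.insert({name, index});
  }
  std::cout << class_indexing["dog"] << "\n";  // prints 1
  return 0;
}
```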
| @@ -87,6 +87,14 @@ class ImageFolderNode : public MappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the dataset from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| #endif | |||
| /// \brief Sampler getter | |||
| /// \return SamplerObj of the current node | |||
| std::shared_ptr<SamplerObj> Sampler() override { return sampler_; } | |||
| @@ -23,6 +23,9 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/manifest_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #endif | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -152,5 +155,34 @@ Status ManifestNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status ManifestNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_file") != json_obj.end(), "Failed to find dataset_file"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode"); | |||
| std::string dataset_file = json_obj["dataset_file"]; | |||
| std::string usage = json_obj["usage"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler)); | |||
| std::map<std::string, int32_t> class_indexing; | |||
| nlohmann::json class_map = json_obj["class_indexing"]; | |||
| for (const auto &class_map_child : class_map) { | |||
| std::string class_ = class_map_child[0]; | |||
| int32_t indexing = class_map_child[1]; | |||
| class_indexing.insert({class_, indexing}); | |||
| } | |||
| bool decode = json_obj["decode"]; | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); | |||
| *ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache); | |||
| (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -78,9 +78,18 @@ class ManifestNode : public MappableSourceNode { | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \param[in] cache Dataset cache for constructor input | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the dataset from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| #endif | |||
| /// \brief Sampler getter | |||
| /// \return SamplerObj of the current node | |||
| std::shared_ptr<SamplerObj> Sampler() override { return sampler_; } | |||
| @@ -22,6 +22,9 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/mnist_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #endif | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -111,5 +114,24 @@ Status MnistNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status MnistNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| std::string usage = json_obj["usage"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler)); | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); | |||
| *ds = std::make_shared<MnistNode>(dataset_dir, usage, sampler, cache); | |||
| (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -78,6 +78,14 @@ class MnistNode : public MappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the dataset from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| #endif | |||
| /// \brief Sampler getter | |||
| /// \return SamplerObj of the current node | |||
| std::shared_ptr<SamplerObj> Sampler() override { return sampler_; } | |||
| @@ -106,6 +106,30 @@ Status DistributedSamplerObj::to_json(nlohmann::json *const out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status DistributedSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("seed") != json_obj.end(), "Failed to find seed"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("offset") != json_obj.end(), "Failed to find offset"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("even_dist") != json_obj.end(), "Failed to find even_dist"); | |||
| int64_t num_shards = json_obj["num_shards"]; | |||
| int64_t shard_id = json_obj["shard_id"]; | |||
| bool shuffle = json_obj["shuffle"]; | |||
| uint32_t seed = json_obj["seed"]; | |||
| int64_t offset = json_obj["offset"]; | |||
| bool even_dist = json_obj["even_dist"]; | |||
| *sampler = | |||
| std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist); | |||
| // Run common code in super class to add children samplers | |||
| RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler)); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
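Note that `num_samples` arrives as a function argument rather than being read from `json_obj` here; presumably the caller (`Serdes::ConstructSampler`, not shown in this diff) extracts it before dispatching to the concrete sampler. A sketch of a record with the keys this function checks (values illustrative):

```cpp
// Illustrative only: the keys DistributedSamplerObj::from_json requires.
// num_samples is deliberately absent; the caller supplies it separately.
#include <cstdint>
#include <iostream>
#include "nlohmann/json.hpp"

int main() {
  nlohmann::json sampler_args = {{"num_shards", 8}, {"shard_id", 3}, {"shuffle", true},
                                 {"seed", 0},       {"offset", -1},  {"even_dist", true}};
  int64_t num_shards = sampler_args["num_shards"];
  bool even_dist = sampler_args["even_dist"];
  std::cout << num_shards << " " << even_dist << "\n";
  return 0;
}
```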
| std::shared_ptr<SamplerObj> DistributedSamplerObj::SamplerCopy() { | |||
| auto sampler = | |||
| std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, even_dist_); | |||
| @@ -56,6 +56,15 @@ class DistributedSamplerObj : public SamplerObj { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *const out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the sampler from a JSON object | |||
| /// \param[in] json_obj JSON object to be read | |||
| /// \param[in] num_samples Number of samples in the sampler | |||
| /// \param[out] sampler Sampler constructed from parameters in JSON object | |||
| /// \return Status of the function | |||
| static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler); | |||
| #endif | |||
| Status ValidateParams() override; | |||
| /// \brief Function to get the shard id of sampler | |||
| @@ -60,6 +60,19 @@ Status PKSamplerObj::to_json(nlohmann::json *const out_json) { | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status PKSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_val") != json_obj.end(), "Failed to find num_val"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle"); | |||
| int64_t num_val = json_obj["num_val"]; | |||
| bool shuffle = json_obj["shuffle"]; | |||
| *sampler = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples); | |||
| // Run common code in super class to add children samplers | |||
| RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler)); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| Status PKSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { | |||
| // runtime sampler object | |||
| *sampler = std::make_shared<dataset::PKSamplerRT>(num_val_, shuffle_, num_samples_); | |||
| @@ -55,6 +55,15 @@ class PKSamplerObj : public SamplerObj { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *const out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the sampler from a JSON object | |||
| /// \param[in] json_obj JSON object to be read | |||
| /// \param[in] num_samples Number of samples in the sampler | |||
| /// \param[out] sampler Sampler constructed from parameters in JSON object | |||
| /// \return Status of the function | |||
| static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler); | |||
| #endif | |||
| Status ValidateParams() override; | |||
| private: | |||
| @@ -56,6 +56,20 @@ Status RandomSamplerObj::to_json(nlohmann::json *const out_json) { | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status RandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("replacement") != json_obj.end(), "Failed to find replacement"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("reshuffle_each_epoch") != json_obj.end(), | |||
| "Failed to find reshuffle_each_epoch"); | |||
| bool replacement = json_obj["replacement"]; | |||
| bool reshuffle_each_epoch = json_obj["reshuffle_each_epoch"]; | |||
| *sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples, reshuffle_each_epoch); | |||
| // Run common code in super class to add children samplers | |||
| RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler)); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| Status RandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { | |||
| // runtime sampler object | |||
| *sampler = std::make_shared<dataset::RandomSamplerRT>(replacement_, num_samples_, reshuffle_each_epoch_); | |||
| @@ -55,6 +55,15 @@ class RandomSamplerObj : public SamplerObj { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *const out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the sampler from a JSON object | |||
| /// \param[in] json_obj JSON object to be read | |||
| /// \param[in] num_samples Number of samples in the sampler | |||
| /// \param[out] sampler Sampler constructed from parameters in JSON object | |||
| /// \return Status of the function | |||
| static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler); | |||
| #endif | |||
| Status ValidateParams() override; | |||
| private: | |||
| @@ -16,6 +16,9 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #endif | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| @@ -73,5 +76,15 @@ Status SamplerObj::to_json(nlohmann::json *const out_json) { | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status SamplerObj::from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *parent_sampler) { | |||
| for (nlohmann::json child : json_obj["child_sampler"]) { | |||
| std::shared_ptr<SamplerObj> child_sampler; | |||
| RETURN_IF_NOT_OK(Serdes::ConstructSampler(child, &child_sampler)); | |||
| (*parent_sampler)->AddChildSampler(child_sampler); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
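`SamplerObj::from_json` above is the shared tail of every sampler's deserialization: it walks `child_sampler` and rebuilds each nested sampler through `Serdes::ConstructSampler`, so arbitrarily nested samplers round-trip. A sketch of such a nested record; the `sampler_name` key and the concrete fields are illustrative assumptions:

```cpp
// Illustrative only: a sampler record carrying one nested child sampler,
// the shape consumed by the child_sampler loop in SamplerObj::from_json.
#include <iostream>
#include <string>
#include "nlohmann/json.hpp"

int main() {
  nlohmann::json parent = nlohmann::json::parse(R"({
    "sampler_name": "RandomSampler",
    "replacement": true,
    "reshuffle_each_epoch": true,
    "child_sampler": [
      {"sampler_name": "SequentialSampler", "start_index": 0, "child_sampler": []}
    ]
  })");

  for (const auto &child : parent["child_sampler"]) {
    std::cout << child["sampler_name"].get<std::string>() << "\n";  // SequentialSampler
  }
  return 0;
}
```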
| @@ -67,6 +67,14 @@ class SamplerObj { | |||
| virtual Status to_json(nlohmann::json *const out_json); | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to construct child samplers from JSON | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] parent_sampler The given parent sampler, returned with its child samplers attached | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *parent_sampler); | |||
| #endif | |||
| std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; } | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -61,6 +61,18 @@ Status SequentialSamplerObj::to_json(nlohmann::json *const out_json) { | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status SequentialSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("start_index") != json_obj.end(), "Failed to find start_index"); | |||
| int64_t start_index = json_obj["start_index"]; | |||
| *sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples); | |||
| // Run common code in super class to add children samplers | |||
| RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler)); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| Status SequentialSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { | |||
| // runtime sampler object | |||
| *sampler = std::make_shared<dataset::SequentialSamplerRT>(start_index_, num_samples_); | |||
| @@ -55,6 +55,15 @@ class SequentialSamplerObj : public SamplerObj { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *const out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the sampler from a JSON object | |||
| /// \param[in] json_obj JSON object to be read | |||
| /// \param[in] num_samples Number of samples in the sampler | |||
| /// \param[out] sampler Sampler constructed from parameters in JSON object | |||
| /// \return Status of the function | |||
| static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler); | |||
| #endif | |||
| Status ValidateParams() override; | |||
| private: | |||
| @@ -63,6 +63,19 @@ Status SubsetRandomSamplerObj::to_json(nlohmann::json *const out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status SubsetRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("indices") != json_obj.end(), "Failed to find indices"); | |||
| std::vector<int64_t> indices = json_obj["indices"]; | |||
| *sampler = std::make_shared<SubsetRandomSamplerObj>(indices, num_samples); | |||
| // Run common code in super class to add children samplers | |||
| RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler)); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| std::shared_ptr<SamplerObj> SubsetRandomSamplerObj::SamplerCopy() { | |||
| auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_); | |||
| for (const auto &child : children_) { | |||
| @@ -45,6 +45,10 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj { | |||
| Status to_json(nlohmann::json *const out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler); | |||
| #endif | |||
| Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override; | |||
| @@ -72,6 +72,17 @@ Status SubsetSamplerObj::to_json(nlohmann::json *const out_json) { | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status SubsetSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("indices") != json_obj.end(), "Failed to find indices"); | |||
| std::vector<int64_t> indices = json_obj["indices"]; | |||
| *sampler = std::make_shared<SubsetSamplerObj>(indices, num_samples); | |||
| // Run common code in super class to add children samplers | |||
| RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler)); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| std::shared_ptr<SamplerObj> SubsetSamplerObj::SamplerCopy() { | |||
| auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_); | |||
| for (const auto &child : children_) { | |||
| @@ -55,6 +55,15 @@ class SubsetSamplerObj : public SamplerObj { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *const out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the sampler from a JSON object | |||
| /// \param[in] json_obj JSON object to be read | |||
| /// \param[in] num_samples Number of samples in the sampler | |||
| /// \param[out] sampler Sampler constructed from parameters in JSON object | |||
| /// \return Status of the function | |||
| static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler); | |||
| #endif | |||
| Status ValidateParams() override; | |||
| protected: | |||
| @@ -63,6 +63,20 @@ Status WeightedRandomSamplerObj::to_json(nlohmann::json *const out_json) { | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status WeightedRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("weights") != json_obj.end(), "Failed to find weights"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("replacement") != json_obj.end(), "Failed to find replacement"); | |||
| std::vector<double> weights = json_obj["weights"]; | |||
| bool replacement = json_obj["replacement"]; | |||
| *sampler = std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement); | |||
| // Run common code in super class to add children samplers | |||
| RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler)); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| Status WeightedRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { | |||
| *sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(weights_, num_samples_, replacement_); | |||
| Status s = BuildChildren(sampler); | |||
| @@ -51,6 +51,15 @@ class WeightedRandomSamplerObj : public SamplerObj { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *const out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read the sampler from a JSON object | |||
| /// \param[in] json_obj JSON object to be read | |||
| /// \param[in] num_samples Number of samples in the sampler | |||
| /// \param[out] sampler Sampler constructed from parameters in JSON object | |||
| /// \return Status of the function | |||
| static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler); | |||
| #endif | |||
| Status ValidateParams() override; | |||
| private: | |||
| @@ -153,6 +153,26 @@ Status TextFileNode::to_json(nlohmann::json *out_json) { | |||
| return Status::OK(); | |||
| } | |||
| Status TextFileNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id"); | |||
| std::vector<std::string> dataset_files = json_obj["dataset_files"]; | |||
| int64_t num_samples = json_obj["num_samples"]; | |||
| ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]); | |||
| int32_t num_shards = json_obj["num_shards"]; | |||
| int32_t shard_id = json_obj["shard_id"]; | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); | |||
| *ds = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache); | |||
| (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
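| // Illustrative only, not part of this change: a sketch of the JSON this from_json expects; | |||
| // file names and counts are made up, and "shuffle" holds the integer value of a ShuffleMode: | |||
| //   {"num_parallel_workers": 4, "dataset_files": ["text1.txt"], "num_samples": 0, | |||
| //    "shuffle": 2, "num_shards": 1, "shard_id": 0} | |||
| // An optional "cache" object, when present, is rebuilt by DatasetCache::from_json. | |||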
| // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. | |||
| // TextFile by itself is a non-mappable dataset that does not support sampling. | |||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| @@ -83,6 +83,12 @@ class TextFileNode : public NonMappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Function to read dataset from JSON object | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| /// \brief TextFile by itself is a non-mappable dataset that does not support sampling. | |||
| /// However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| /// inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| @@ -229,6 +229,33 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) { | |||
| return Status::OK(); | |||
| } | |||
| Status TFRecordNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("schema") != json_obj.end(), "Failed to find schema"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("columns_list") != json_obj.end(), "Failed to find columns_list"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_equal_rows") != json_obj.end(), "Failed to find shard_equal_rows"); | |||
| std::vector<std::string> dataset_files = json_obj["dataset_files"]; | |||
| std::string schema = json_obj["schema"]; | |||
| std::vector<std::string> columns_list = json_obj["columns_list"]; | |||
| int64_t num_samples = json_obj["num_samples"]; | |||
| ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]); | |||
| int32_t num_shards = json_obj["num_shards"]; | |||
| int32_t shard_id = json_obj["shard_id"]; | |||
| bool shard_equal_rows = json_obj["shard_equal_rows"]; | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); | |||
| *ds = std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, shuffle, num_shards, shard_id, | |||
| shard_equal_rows, cache); | |||
| (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
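| // Illustrative only, not part of this change: compared with TextFileNode above, the TFRecord | |||
| // JSON additionally carries "schema", "columns_list" and "shard_equal_rows", e.g. (values made up): | |||
| //   {"num_parallel_workers": 4, "dataset_files": ["a.tfrecord"], "schema": "", | |||
| //    "columns_list": ["image", "label"], "num_samples": 0, "shuffle": 2, | |||
| //    "num_shards": 1, "shard_id": 0, "shard_equal_rows": false} | |||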
| // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. | |||
| // TFRecord by itself is a non-mappable dataset that does not support sampling. | |||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| @@ -126,6 +126,12 @@ class TFRecordNode : public NonMappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Function to read dataset from JSON object | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| /// \brief TFRecord by itself is a non-mappable dataset that does not support sampling. | |||
| /// However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| /// inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| @@ -23,6 +23,9 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/voc_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #endif | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -169,6 +172,7 @@ Status VOCNode::to_json(nlohmann::json *out_json) { | |||
| args["usage"] = usage_; | |||
| args["class_indexing"] = class_index_; | |||
| args["decode"] = decode_; | |||
| args["extra_metadata"] = extra_metadata_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| @@ -177,5 +181,38 @@ Status VOCNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status VOCNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extra_metadata") != json_obj.end(), "Failed to find extra_metadata"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| std::string task = json_obj["task"]; | |||
| std::string usage = json_obj["usage"]; | |||
| std::map<std::string, int32_t> class_indexing; | |||
| nlohmann::json class_map = json_obj["class_indexing"]; | |||
| for (const auto &class_map_child : class_map) { | |||
| std::string class_ = class_map_child[0]; | |||
| int32_t indexing = class_map_child[1]; | |||
| class_indexing.insert({class_, indexing}); | |||
| } | |||
| bool decode = json_obj["decode"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler)); | |||
| bool extra_metadata = json_obj["extra_metadata"]; | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); | |||
| *ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler, cache, extra_metadata); | |||
| (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
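| // Illustrative only, not part of this change: "class_indexing" is read above as an array of | |||
| // [name, index] pairs rather than a plain JSON map, e.g. (names made up): | |||
| //   "class_indexing": [["car", 0], ["person", 1]] | |||
| // The required "sampler" member is rebuilt via Serdes::ConstructSampler and an optional | |||
| // "cache" member via DatasetCache::from_json. | |||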
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -83,6 +83,14 @@ class VOCNode : public MappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Function to read dataset from JSON object | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| #endif | |||
| /// \brief Sampler getter | |||
| /// \return SamplerObj of the current node | |||
| std::shared_ptr<SamplerObj> Sampler() override { return sampler_; } | |||
| @@ -91,5 +91,13 @@ Status TakeNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status TakeNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count"); | |||
| int32_t count = json_obj["count"]; | |||
| *result = std::make_shared<TakeNode>(ds, count); | |||
| return Status::OK(); | |||
| } | |||
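| // Illustrative only, not part of this change: a minimal sketch of deserializing a take | |||
| // operation on top of an already constructed input node `ds` (count value is made up): | |||
| //   nlohmann::json take_json = {{"count", 10}}; | |||
| //   std::shared_ptr<DatasetNode> taken; | |||
| //   Status rc = TakeNode::from_json(take_json, ds, &taken); | |||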
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -88,6 +88,14 @@ class TakeNode : public DatasetNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Function to read dataset operation from JSON object | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[in] ds Already constructed dataset node that the operation is applied to | |||
| /// \param[out] result Deserialized dataset after the operation | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| private: | |||
| int32_t take_count_; | |||
| }; | |||
| @@ -126,5 +126,25 @@ Status TransferNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status TransferNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("queue_name") != json_obj.end(), "Failed to find queue_name"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("device_type") != json_obj.end(), "Failed to find device_type"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("device_id") != json_obj.end(), "Failed to find device_id"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("send_epoch_end") != json_obj.end(), "Failed to find send_epoch_end"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("total_batch") != json_obj.end(), "Failed to find total_batch"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("create_data_info_queue") != json_obj.end(), | |||
| "Failed to find create_data_info_queue"); | |||
| std::string queue_name = json_obj["queue_name"]; | |||
| std::string device_type = json_obj["device_type"]; | |||
| int32_t device_id = json_obj["device_id"]; | |||
| bool send_epoch_end = json_obj["send_epoch_end"]; | |||
| int32_t total_batch = json_obj["total_batch"]; | |||
| bool create_data_info_queue = json_obj["create_data_info_queue"]; | |||
| *result = std::make_shared<TransferNode>(ds, queue_name, device_type, device_id, send_epoch_end, total_batch, | |||
| create_data_info_queue); | |||
| return Status::OK(); | |||
| } | |||
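| // Illustrative only, not part of this change: the JSON this from_json expects, with made-up values: | |||
| //   {"queue_name": "q0", "device_type": "Ascend", "device_id": 0, | |||
| //    "send_epoch_end": true, "total_batch": 0, "create_data_info_queue": false} | |||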
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -84,6 +84,14 @@ class TransferNode : public DatasetNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Function to read dataset operation from JSON object | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[in] ds Already constructed dataset node that the operation is applied to | |||
| /// \param[out] result Deserialized dataset after the operation | |||
| /// \return Status The status code returned | |||
| static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| private: | |||
| std::string queue_name_; | |||
| int32_t device_id_; | |||
| @@ -124,584 +124,97 @@ Status Serdes::CreateNode(std::shared_ptr<DatasetNode> child_ds, nlohmann::json | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateCelebADatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extensions") != json_obj.end(), "Failed to find extension"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| std::string usage = json_obj["usage"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler)); | |||
| bool decode = json_obj["decode"]; | |||
| std::set<std::string> extension = json_obj["extensions"]; | |||
| // default value for cache - to_json function does not have the output | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| *ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extension, cache); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateCifar10DatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| std::string usage = json_obj["usage"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler)); | |||
| // default value for cache - to_json function does not have the output | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| *ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler, cache); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateCifar100DatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| std::string usage = json_obj["usage"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler)); | |||
| // default value for cache - to_json function does not have the output | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| *ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler, cache); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateCLUEDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id"); | |||
| std::vector<std::string> dataset_files = json_obj["dataset_dir"]; | |||
| std::string task = json_obj["task"]; | |||
| std::string usage = json_obj["usage"]; | |||
| int64_t num_samples = json_obj["num_samples"]; | |||
| ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]); | |||
| int32_t num_shards = json_obj["num_shards"]; | |||
| int32_t shard_id = json_obj["shard_id"]; | |||
| // default value for cache - to_json function does not have the output | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| *ds = std::make_shared<CLUENode>(dataset_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateCocoDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| std::string annotation_file = json_obj["annotation_file"]; | |||
| std::string task = json_obj["task"]; | |||
| bool decode = json_obj["decode"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler)); | |||
| // default value for cache and extra_metadata - to_json function does not have the output | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| bool extra_metadata = false; | |||
| *ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler, cache, extra_metadata); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateCSVDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("field_delim") != json_obj.end(), "Failed to find field_delim"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("column_names") != json_obj.end(), "Failed to find column_names"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id"); | |||
| std::vector<std::string> dataset_files = json_obj["dataset_files"]; | |||
| std::string field_delim = json_obj["field_delim"]; | |||
| std::vector<std::shared_ptr<CsvBase>> column_defaults = {}; | |||
| std::vector<std::string> column_names = json_obj["column_names"]; | |||
| int64_t num_samples = json_obj["num_samples"]; | |||
| ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]); | |||
| int32_t num_shards = json_obj["num_shards"]; | |||
| int32_t shard_id = json_obj["shard_id"]; | |||
| // default value for cache - to_json function does not have the output | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| *ds = std::make_shared<CSVNode>(dataset_files, field_delim.c_str()[0], column_defaults, column_names, num_samples, | |||
| shuffle, num_shards, shard_id, cache); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateImageFolderDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extensions") != json_obj.end(), "Failed to find extension"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| bool decode = json_obj["decode"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler)); | |||
| // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false. | |||
| bool recursive = false; | |||
| std::set<std::string> extension = json_obj["extensions"]; | |||
| std::map<std::string, int32_t> class_indexing; | |||
| nlohmann::json class_map = json_obj["class_indexing"]; | |||
| for (const auto &class_map_child : class_map) { | |||
| std::string class_ = class_map_child[0]; | |||
| int32_t indexing = class_map_child[1]; | |||
| class_indexing.insert({class_, indexing}); | |||
| } | |||
| // default value for cache - to_json function does not have the output | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| *ds = std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extension, class_indexing, cache); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateManifestDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_file") != json_obj.end(), "Failed to find dataset_file"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode"); | |||
| std::string dataset_file = json_obj["dataset_file"]; | |||
| std::string usage = json_obj["usage"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler)); | |||
| std::map<std::string, int32_t> class_indexing; | |||
| nlohmann::json class_map = json_obj["class_indexing"]; | |||
| for (const auto &class_map_child : class_map) { | |||
| std::string class_ = class_map_child[0]; | |||
| int32_t indexing = class_map_child[1]; | |||
| class_indexing.insert({class_, indexing}); | |||
| } | |||
| bool decode = json_obj["decode"]; | |||
| // default value for cache - to_json function does not have the output | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| *ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateMnistDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| std::string usage = json_obj["usage"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler)); | |||
| // default value for cache - to_json function does not have the output | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| *ds = std::make_shared<MnistNode>(dataset_dir, usage, sampler, cache); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateTextFileDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id"); | |||
| std::vector<std::string> dataset_files = json_obj["dataset_files"]; | |||
| int64_t num_samples = json_obj["num_samples"]; | |||
| ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]); | |||
| int32_t num_shards = json_obj["num_shards"]; | |||
| int32_t shard_id = json_obj["shard_id"]; | |||
| // default value for cache - to_json function does not have the output | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| *ds = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateTFRecordDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("schema") != json_obj.end(), "Failed to find schema"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("columns_list") != json_obj.end(), "Failed to find columns_list"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_equal_rows") != json_obj.end(), "Failed to find shard_equal_rows"); | |||
| std::vector<std::string> dataset_files = json_obj["dataset_files"]; | |||
| std::string schema = json_obj["schema"]; | |||
| std::vector<std::string> columns_list = json_obj["columns_list"]; | |||
| int64_t num_samples = json_obj["num_samples"]; | |||
| ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]); | |||
| int32_t num_shards = json_obj["num_shards"]; | |||
| int32_t shard_id = json_obj["shard_id"]; | |||
| bool shard_equal_rows = json_obj["shard_equal_rows"]; | |||
| // default value for cache - to_json function does not have the output | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| *ds = std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, shuffle, num_shards, shard_id, | |||
| shard_equal_rows, cache); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateVOCDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler"); | |||
| std::string dataset_dir = json_obj["dataset_dir"]; | |||
| std::string task = json_obj["task"]; | |||
| std::string usage = json_obj["usage"]; | |||
| std::map<std::string, int32_t> class_indexing; | |||
| nlohmann::json class_map = json_obj["class_indexing"]; | |||
| for (const auto &class_map_child : class_map) { | |||
| std::string class_ = class_map_child[0]; | |||
| int32_t indexing = class_map_child[1]; | |||
| class_indexing.insert({class_, indexing}); | |||
| } | |||
| bool decode = json_obj["decode"]; | |||
| std::shared_ptr<SamplerObj> sampler; | |||
| RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler)); | |||
| // default value for cache and extra_metadata - to_json function does not have the output | |||
| std::shared_ptr<DatasetCache> cache = nullptr; | |||
| bool extra_metadata = false; | |||
| *ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler, cache, extra_metadata); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateDatasetNode(nlohmann::json json_obj, std::string op_type, std::shared_ptr<DatasetNode> *ds) { | |||
| if (op_type == kCelebANode) { | |||
| RETURN_IF_NOT_OK(CreateCelebADatasetNode(json_obj, ds)); | |||
| RETURN_IF_NOT_OK(CelebANode::from_json(json_obj, ds)); | |||
| } else if (op_type == kCifar10Node) { | |||
| RETURN_IF_NOT_OK(CreateCifar10DatasetNode(json_obj, ds)); | |||
| RETURN_IF_NOT_OK(Cifar10Node::from_json(json_obj, ds)); | |||
| } else if (op_type == kCifar100Node) { | |||
| RETURN_IF_NOT_OK(CreateCifar100DatasetNode(json_obj, ds)); | |||
| RETURN_IF_NOT_OK(Cifar100Node::from_json(json_obj, ds)); | |||
| } else if (op_type == kCLUENode) { | |||
| RETURN_IF_NOT_OK(CreateCLUEDatasetNode(json_obj, ds)); | |||
| RETURN_IF_NOT_OK(CLUENode::from_json(json_obj, ds)); | |||
| } else if (op_type == kCocoNode) { | |||
| RETURN_IF_NOT_OK(CreateCocoDatasetNode(json_obj, ds)); | |||
| RETURN_IF_NOT_OK(CocoNode::from_json(json_obj, ds)); | |||
| } else if (op_type == kCSVNode) { | |||
| RETURN_IF_NOT_OK(CreateCSVDatasetNode(json_obj, ds)); | |||
| RETURN_IF_NOT_OK(CSVNode::from_json(json_obj, ds)); | |||
| } else if (op_type == kImageFolderNode) { | |||
| RETURN_IF_NOT_OK(CreateImageFolderDatasetNode(json_obj, ds)); | |||
| RETURN_IF_NOT_OK(ImageFolderNode::from_json(json_obj, ds)); | |||
| } else if (op_type == kManifestNode) { | |||
| RETURN_IF_NOT_OK(CreateManifestDatasetNode(json_obj, ds)); | |||
| RETURN_IF_NOT_OK(ManifestNode::from_json(json_obj, ds)); | |||
| } else if (op_type == kMnistNode) { | |||
| RETURN_IF_NOT_OK(CreateMnistDatasetNode(json_obj, ds)); | |||
| RETURN_IF_NOT_OK(MnistNode::from_json(json_obj, ds)); | |||
| } else if (op_type == kTextFileNode) { | |||
| RETURN_IF_NOT_OK(CreateTextFileDatasetNode(json_obj, ds)); | |||
| RETURN_IF_NOT_OK(TextFileNode::from_json(json_obj, ds)); | |||
| } else if (op_type == kTFRecordNode) { | |||
| RETURN_IF_NOT_OK(CreateTFRecordDatasetNode(json_obj, ds)); | |||
| RETURN_IF_NOT_OK(TFRecordNode::from_json(json_obj, ds)); | |||
| } else if (op_type == kVOCNode) { | |||
| RETURN_IF_NOT_OK(CreateVOCDatasetNode(json_obj, ds)); | |||
| RETURN_IF_NOT_OK(VOCNode::from_json(json_obj, ds)); | |||
| } else { | |||
| return Status(StatusCode::kMDUnexpectedError, op_type + " is not supported"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateBatchOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("batch_size") != json_obj.end(), "Failed to find batch_size"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("drop_remainder") != json_obj.end(), "Failed to find drop_remainder"); | |||
| int32_t batch_size = json_obj["batch_size"]; | |||
| bool drop_remainder = json_obj["drop_remainder"]; | |||
| *result = std::make_shared<BatchNode>(ds, batch_size, drop_remainder); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateMapOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(), | |||
| "Failed to find num_parallel_workers"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("input_columns") != json_obj.end(), "Failed to find input_columns"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("output_columns") != json_obj.end(), "Failed to find output_columns"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("project_columns") != json_obj.end(), "Failed to find project_columns"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("operations") != json_obj.end(), "Failed to find operations"); | |||
| std::vector<std::string> input_columns = json_obj["input_columns"]; | |||
| std::vector<std::string> output_columns = json_obj["output_columns"]; | |||
| std::vector<std::string> project_columns = json_obj["project_columns"]; | |||
| std::vector<std::shared_ptr<TensorOperation>> operations; | |||
| RETURN_IF_NOT_OK(ConstructTensorOps(json_obj["operations"], &operations)); | |||
| *result = std::make_shared<MapNode>(ds, operations, input_columns, output_columns, project_columns); | |||
| (*result)->SetNumWorkers(json_obj["num_parallel_workers"]); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateProjectOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("columns") != json_obj.end(), "Failed to find columns"); | |||
| std::vector<std::string> columns = json_obj["columns"]; | |||
| *result = std::make_shared<ProjectNode>(ds, columns); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateRenameOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("input_columns") != json_obj.end(), "Failed to find input_columns"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("output_columns") != json_obj.end(), "Failed to find output_columns"); | |||
| std::vector<std::string> input_columns = json_obj["input_columns"]; | |||
| std::vector<std::string> output_columns = json_obj["output_columns"]; | |||
| *result = std::make_shared<RenameNode>(ds, input_columns, output_columns); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateRepeatOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count"); | |||
| int32_t count = json_obj["count"]; | |||
| *result = std::make_shared<RepeatNode>(ds, count); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateShuffleOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("buffer_size") != json_obj.end(), "Failed to find buffer_size"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("reshuffle_each_epoch") != json_obj.end(), | |||
| "Failed to find reshuffle_each_epoch"); | |||
| int32_t buffer_size = json_obj["buffer_size"]; | |||
| bool reset_every_epoch = json_obj["reshuffle_each_epoch"]; | |||
| *result = std::make_shared<ShuffleNode>(ds, buffer_size, reset_every_epoch); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateSkipOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count"); | |||
| int32_t count = json_obj["count"]; | |||
| *result = std::make_shared<SkipNode>(ds, count); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateTransferOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("queue_name") != json_obj.end(), "Failed to find queue_name"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("device_type") != json_obj.end(), "Failed to find device_type"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("device_id") != json_obj.end(), "Failed to find device_id"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("send_epoch_end") != json_obj.end(), "Failed to find send_epoch_end"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("total_batch") != json_obj.end(), "Failed to find total_batch"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("create_data_info_queue") != json_obj.end(), | |||
| "Failed to find create_data_info_queue"); | |||
| std::string queue_name = json_obj["queue_name"]; | |||
| std::string device_type = json_obj["device_type"]; | |||
| int32_t device_id = json_obj["device_id"]; | |||
| bool send_epoch_end = json_obj["send_epoch_end"]; | |||
| int32_t total_batch = json_obj["total_batch"]; | |||
| bool create_data_info_queue = json_obj["create_data_info_queue"]; | |||
| *result = std::make_shared<TransferNode>(ds, queue_name, device_type, device_id, send_epoch_end, total_batch, | |||
| create_data_info_queue); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateTakeOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count"); | |||
| int32_t count = json_obj["count"]; | |||
| *result = std::make_shared<TakeNode>(ds, count); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::CreateDatasetOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, std::string op_type, | |||
| std::shared_ptr<DatasetNode> *result) { | |||
| if (op_type == kBatchNode) { | |||
| RETURN_IF_NOT_OK(CreateBatchOperationNode(ds, json_obj, result)); | |||
| RETURN_IF_NOT_OK(BatchNode::from_json(json_obj, ds, result)); | |||
| } else if (op_type == kMapNode) { | |||
| RETURN_IF_NOT_OK(CreateMapOperationNode(ds, json_obj, result)); | |||
| RETURN_IF_NOT_OK(MapNode::from_json(json_obj, ds, result)); | |||
| } else if (op_type == kProjectNode) { | |||
| RETURN_IF_NOT_OK(CreateProjectOperationNode(ds, json_obj, result)); | |||
| RETURN_IF_NOT_OK(ProjectNode::from_json(json_obj, ds, result)); | |||
| } else if (op_type == kRenameNode) { | |||
| RETURN_IF_NOT_OK(CreateRenameOperationNode(ds, json_obj, result)); | |||
| RETURN_IF_NOT_OK(RenameNode::from_json(json_obj, ds, result)); | |||
| } else if (op_type == kRepeatNode) { | |||
| RETURN_IF_NOT_OK(CreateRepeatOperationNode(ds, json_obj, result)); | |||
| RETURN_IF_NOT_OK(RepeatNode::from_json(json_obj, ds, result)); | |||
| } else if (op_type == kShuffleNode) { | |||
| RETURN_IF_NOT_OK(CreateShuffleOperationNode(ds, json_obj, result)); | |||
| RETURN_IF_NOT_OK(ShuffleNode::from_json(json_obj, ds, result)); | |||
| } else if (op_type == kSkipNode) { | |||
| RETURN_IF_NOT_OK(CreateSkipOperationNode(ds, json_obj, result)); | |||
| RETURN_IF_NOT_OK(SkipNode::from_json(json_obj, ds, result)); | |||
| } else if (op_type == kTransferNode) { | |||
| RETURN_IF_NOT_OK(CreateTransferOperationNode(ds, json_obj, result)); | |||
| RETURN_IF_NOT_OK(TransferNode::from_json(json_obj, ds, result)); | |||
| } else if (op_type == kTakeNode) { | |||
| RETURN_IF_NOT_OK(CreateTakeOperationNode(ds, json_obj, result)); | |||
| RETURN_IF_NOT_OK(TakeNode::from_json(json_obj, ds, result)); | |||
| } else { | |||
| return Status(StatusCode::kMDUnexpectedError, op_type + " operation is not supported"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::ConstructDistributedSampler(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("seed") != json_obj.end(), "Failed to find seed"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("offset") != json_obj.end(), "Failed to find offset"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("even_dist") != json_obj.end(), "Failed to find even_dist"); | |||
| int64_t num_shards = json_obj["num_shards"]; | |||
| int64_t shard_id = json_obj["shard_id"]; | |||
| bool shuffle = json_obj["shuffle"]; | |||
| uint32_t seed = json_obj["seed"]; | |||
| int64_t offset = json_obj["offset"]; | |||
| bool even_dist = json_obj["even_dist"]; | |||
| *sampler = | |||
| std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist); | |||
| if (json_obj.find("child_sampler") != json_obj.end()) { | |||
| std::shared_ptr<SamplerObj> parent_sampler = *sampler; | |||
| RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent_sampler, sampler)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::ConstructPKSampler(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_val") != json_obj.end(), "Failed to find num_val"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle"); | |||
| int64_t num_val = json_obj["num_val"]; | |||
| bool shuffle = json_obj["shuffle"]; | |||
| *sampler = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples); | |||
| if (json_obj.find("child_sampler") != json_obj.end()) { | |||
| std::shared_ptr<SamplerObj> parent_sampler = *sampler; | |||
| RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent_sampler, sampler)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::ConstructRandomSampler(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("replacement") != json_obj.end(), "Failed to find replacement"); | |||
| bool replacement = json_obj["replacement"]; | |||
| *sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples); | |||
| if (json_obj.find("child_sampler") != json_obj.end()) { | |||
| std::shared_ptr<SamplerObj> parent_sampler = *sampler; | |||
| RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent_sampler, sampler)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::ConstructSequentialSampler(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("start_index") != json_obj.end(), "Failed to find start_index"); | |||
| int64_t start_index = json_obj["start_index"]; | |||
| *sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples); | |||
| if (json_obj.find("child_sampler") != json_obj.end()) { | |||
| std::shared_ptr<SamplerObj> parent_sampler = *sampler; | |||
| RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent_sampler, sampler)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::ConstructSubsetRandomSampler(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("indices") != json_obj.end(), "Failed to find indices"); | |||
| std::vector<int64_t> indices = json_obj["indices"]; | |||
| *sampler = std::make_shared<SubsetRandomSamplerObj>(indices, num_samples); | |||
| if (json_obj.find("child_sampler") != json_obj.end()) { | |||
| std::shared_ptr<SamplerObj> parent_sampler = *sampler; | |||
| RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent_sampler, sampler)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::ConstructWeightedRandomSampler(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("replacement") != json_obj.end(), "Failed to find replacement"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("weights") != json_obj.end(), "Failed to find weights"); | |||
| bool replacement = json_obj["replacement"]; | |||
| std::vector<double> weights = json_obj["weights"]; | |||
| *sampler = std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement); | |||
| if (json_obj.find("child_sampler") != json_obj.end()) { | |||
| std::shared_ptr<SamplerObj> parent_sampler = *sampler; | |||
| RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent_sampler, sampler)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::ConstructSampler(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler_name") != json_obj.end(), "Failed to find sampler_name"); | |||
| int64_t num_samples = json_obj["num_samples"]; | |||
| std::string sampler_name = json_obj["sampler_name"]; | |||
| if (sampler_name == "DistributedSampler") { | |||
| RETURN_IF_NOT_OK(ConstructDistributedSampler(json_obj, num_samples, sampler)); | |||
| RETURN_IF_NOT_OK(DistributedSamplerObj::from_json(json_obj, num_samples, sampler)); | |||
| } else if (sampler_name == "PKSampler") { | |||
| RETURN_IF_NOT_OK(ConstructPKSampler(json_obj, num_samples, sampler)); | |||
| RETURN_IF_NOT_OK(PKSamplerObj::from_json(json_obj, num_samples, sampler)); | |||
| } else if (sampler_name == "RandomSampler") { | |||
| RETURN_IF_NOT_OK(ConstructRandomSampler(json_obj, num_samples, sampler)); | |||
| RETURN_IF_NOT_OK(RandomSamplerObj::from_json(json_obj, num_samples, sampler)); | |||
| } else if (sampler_name == "SequentialSampler") { | |||
| RETURN_IF_NOT_OK(ConstructSequentialSampler(json_obj, num_samples, sampler)); | |||
| RETURN_IF_NOT_OK(SequentialSamplerObj::from_json(json_obj, num_samples, sampler)); | |||
| } else if (sampler_name == "SubsetSampler") { | |||
| RETURN_IF_NOT_OK(SubsetSamplerObj::from_json(json_obj, num_samples, sampler)); | |||
| } else if (sampler_name == "SubsetRandomSampler") { | |||
| RETURN_IF_NOT_OK(ConstructSubsetRandomSampler(json_obj, num_samples, sampler)); | |||
| RETURN_IF_NOT_OK(SubsetRandomSamplerObj::from_json(json_obj, num_samples, sampler)); | |||
| } else if (sampler_name == "WeightedRandomSampler") { | |||
| RETURN_IF_NOT_OK(ConstructWeightedRandomSampler(json_obj, num_samples, sampler)); | |||
| RETURN_IF_NOT_OK(WeightedRandomSamplerObj::from_json(json_obj, num_samples, sampler)); | |||
| } else { | |||
| return Status(StatusCode::kMDUnexpectedError, sampler_name + " is not supported"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
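| // Illustrative only, not part of this change: samplers can nest through an optional | |||
| // "child_sampler" array; the common SamplerObj::from_json code invoked by each sampler's | |||
| // from_json attaches the children, e.g. (values made up): | |||
| //   {"sampler_name": "RandomSampler", "num_samples": 0, "replacement": false, | |||
| //    "child_sampler": [{"sampler_name": "SequentialSampler", "num_samples": 0, "start_index": 0}]} | |||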
| Status Serdes::ChildSamplerFromJson(nlohmann::json json_obj, std::shared_ptr<SamplerObj> parent_sampler, | |||
| std::shared_ptr<SamplerObj> *sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("child_sampler") != json_obj.end(), "Failed to find child_sampler"); | |||
| for (nlohmann::json child : json_obj["child_sampler"]) { | |||
| std::shared_ptr<SamplerObj> child_sampler; | |||
| RETURN_IF_NOT_OK(ConstructSampler(child, &child_sampler)); | |||
| parent_sampler.get()->AddChildSampler(child_sampler); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::BoundingBoxAugmentFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("transform") != op_params.end(), "Failed to find transform"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("ratio") != op_params.end(), "Failed to find ratio"); | |||
| std::vector<std::shared_ptr<TensorOperation>> transforms; | |||
| std::vector<nlohmann::json> json_operations = {}; | |||
| json_operations.push_back(op_params["transform"]); | |||
| RETURN_IF_NOT_OK(ConstructTensorOps(json_operations, &transforms)); | |||
| float ratio = op_params["ratio"]; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(transforms.size() == 1, | |||
| "Expect size one of transforms parameter, but got:" + std::to_string(transforms.size())); | |||
| *operation = std::make_shared<vision::BoundingBoxAugmentOperation>(transforms[0], ratio); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::RandomSelectSubpolicyFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("policy") != op_params.end(), "Failed to find policy"); | |||
| nlohmann::json policy_json = op_params["policy"]; | |||
| std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy; | |||
| std::vector<std::pair<std::shared_ptr<TensorOperation>, double>> policy_items; | |||
| for (nlohmann::json item : policy_json) { | |||
| for (nlohmann::json item_pair : item) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(item_pair.find("prob") != item_pair.end(), "Failed to find prob"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(item_pair.find("tensor_op") != item_pair.end(), "Failed to find tensor_op"); | |||
| std::vector<std::shared_ptr<TensorOperation>> operations; | |||
| std::pair<std::shared_ptr<TensorOperation>, double> policy_pair; | |||
| std::shared_ptr<TensorOperation> operation; | |||
| nlohmann::json tensor_op_json; | |||
| double prob = item_pair["prob"]; | |||
| tensor_op_json.push_back(item_pair["tensor_op"]); | |||
| RETURN_IF_NOT_OK(ConstructTensorOps(tensor_op_json, &operations)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(operations.size() == 1, "There should be only 1 tensor operation"); | |||
| policy_pair = std::make_pair(operations[0], prob); | |||
| policy_items.push_back(policy_pair); | |||
| } | |||
| policy.push_back(policy_items); | |||
| } | |||
| *operation = std::make_shared<vision::RandomSelectSubpolicyOperation>(policy); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::UniformAugFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("transforms") != op_params.end(), "Failed to find transforms"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("num_ops") != op_params.end(), "Failed to find num_ops"); | |||
| std::vector<std::shared_ptr<TensorOperation>> transforms = {}; | |||
| RETURN_IF_NOT_OK(ConstructTensorOps(op_params["transforms"], &transforms)); | |||
| int32_t num_ops = op_params["num_ops"]; | |||
| *operation = std::make_shared<vision::UniformAugOperation>(transforms, num_ops); | |||
| return Status::OK(); | |||
| } | |||
| Status Serdes::ConstructTensorOps(nlohmann::json operations, std::vector<std::shared_ptr<TensorOperation>> *result) { | |||
| Status Serdes::ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result) { | |||
| std::vector<std::shared_ptr<TensorOperation>> output; | |||
| for (auto op : operations) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op.find("is_python_front_end_op") == op.end(), | |||
| for (nlohmann::json item : json_obj) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(item.find("is_python_front_end_op") == item.end(), | |||
| "python operation is not yet supported"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op.find("tensor_op_name") != op.end(), "Failed to find tensor_op_name"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op.find("tensor_op_params") != op.end(), "Failed to find tensor_op_params"); | |||
| std::string op_name = op["tensor_op_name"]; | |||
| nlohmann::json op_params = op["tensor_op_params"]; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_name") != item.end(), "Failed to find tensor_op_name"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_params") != item.end(), "Failed to find tensor_op_params"); | |||
| std::string op_name = item["tensor_op_name"]; | |||
| nlohmann::json op_params = item["tensor_op_params"]; | |||
| std::shared_ptr<TensorOperation> operation = nullptr; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(func_ptr_.find(op_name) != func_ptr_.end(), "Failed to find " + op_name); | |||
| RETURN_IF_NOT_OK(func_ptr_[op_name](op_params, &operation)); | |||
| @@ -716,7 +229,7 @@ Serdes::InitializeFuncPtr() { | |||
| std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> * operation)> ops_ptr; | |||
| ops_ptr[vision::kAffineOperation] = &(vision::AffineOperation::from_json); | |||
| ops_ptr[vision::kAutoContrastOperation] = &(vision::AutoContrastOperation::from_json); | |||
| ops_ptr[vision::kBoundingBoxAugmentOperation] = &(BoundingBoxAugmentFromJson); | |||
| ops_ptr[vision::kBoundingBoxAugmentOperation] = &(vision::BoundingBoxAugmentOperation::from_json); | |||
| ops_ptr[vision::kCenterCropOperation] = &(vision::CenterCropOperation::from_json); | |||
| ops_ptr[vision::kCropOperation] = &(vision::CropOperation::from_json); | |||
| ops_ptr[vision::kCutMixBatchOperation] = &(vision::CutMixBatchOperation::from_json); | |||
| @@ -745,7 +258,7 @@ Serdes::InitializeFuncPtr() { | |||
| ops_ptr[vision::kRandomResizedCropOperation] = &(vision::RandomResizedCropOperation::from_json); | |||
| ops_ptr[vision::kRandomResizedCropWithBBoxOperation] = &(vision::RandomResizedCropWithBBoxOperation::from_json); | |||
| ops_ptr[vision::kRandomRotationOperation] = &(vision::RandomRotationOperation::from_json); | |||
| ops_ptr[vision::kRandomSelectSubpolicyOperation] = &(RandomSelectSubpolicyFromJson); | |||
| ops_ptr[vision::kRandomSelectSubpolicyOperation] = &(vision::RandomSelectSubpolicyOperation::from_json); | |||
| ops_ptr[vision::kRandomSharpnessOperation] = &(vision::RandomSharpnessOperation::from_json); | |||
| ops_ptr[vision::kRandomSolarizeOperation] = &(vision::RandomSolarizeOperation::from_json); | |||
| ops_ptr[vision::kRandomVerticalFlipOperation] = &(vision::RandomVerticalFlipOperation::from_json); | |||
| @@ -766,7 +279,7 @@ Serdes::InitializeFuncPtr() { | |||
| &(vision::SoftDvppDecodeRandomCropResizeJpegOperation::from_json); | |||
| ops_ptr[vision::kSoftDvppDecodeResizeJpegOperation] = &(vision::SoftDvppDecodeResizeJpegOperation::from_json); | |||
| ops_ptr[vision::kSwapRedBlueOperation] = &(vision::SwapRedBlueOperation::from_json); | |||
| ops_ptr[vision::kUniformAugOperation] = &(UniformAugFromJson); | |||
| ops_ptr[vision::kUniformAugOperation] = &(vision::UniformAugOperation::from_json); | |||
| ops_ptr[vision::kVerticalFlipOperation] = &(vision::VerticalFlipOperation::from_json); | |||
| ops_ptr[transforms::kFillOperation] = &(transforms::FillOperation::from_json); | |||
| ops_ptr[transforms::kOneHotOperation] = &(transforms::OneHotOperation::from_json); | |||
| @@ -159,6 +159,18 @@ class Serdes { | |||
| /// \return Status The status code returned | |||
| static Status ConstructPipeline(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| /// \brief Helper function for creating samplers; dispatches on sampler type and calls the related from_json | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] sampler Deserialized sampler | |||
| /// \return Status The status code returned | |||
| static Status ConstructSampler(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler); | |||
| /// \brief Helper function to construct tensor operations | |||
| /// \param[in] json_obj JSON object of operations to be deserialized | |||
| /// \param[out] result Vector of deserialized tensor operation pointers | |||
| /// \return Status The status code returned | |||
| static Status ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result); | |||
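| /// Illustrative only, not part of this change: each element of the operations array is expected | |||
| /// to carry "tensor_op_name" and "tensor_op_params", e.g. (operation name and params made up): | |||
| ///   [{"tensor_op_name": "SomeOp", "tensor_op_params": {"param": 1}}] | |||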
| protected: | |||
| /// \brief Helper function to save JSON to a file | |||
| /// \param[in] json_string The JSON string to be saved to the file | |||
| @@ -189,91 +201,6 @@ class Serdes { | |||
| static Status CreateDatasetOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::string op_type, std::shared_ptr<DatasetNode> *result); | |||
| /// \brief Helper functions for creating sampler, separate different samplers and call the related function | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] sampler Deserialized sampler | |||
| /// \return Status The status code returned | |||
| static Status ConstructSampler(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler); | |||
| /// \brief helper function to construct tensor operations | |||
| /// \param[in] operations operations to be deserilized | |||
| /// \param[out] vector of tensor operation pointer | |||
| /// \return Status The status code returned | |||
| static Status ConstructTensorOps(nlohmann::json operations, std::vector<std::shared_ptr<TensorOperation>> *result); | |||
| /// \brief Helper functions for different datasets | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] ds Deserialized dataset | |||
| /// \return Status The status code returned | |||
| static Status CreateCelebADatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| static Status CreateCifar10DatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| static Status CreateCifar100DatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| static Status CreateCLUEDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| static Status CreateCocoDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| static Status CreateCSVDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| static Status CreateImageFolderDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| static Status CreateManifestDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| static Status CreateMnistDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| static Status CreateTextFileDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| static Status CreateTFRecordDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| static Status CreateVOCDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); | |||
| /// \brief Helper functions for different operations | |||
| /// \param[in] ds dataset node constructed | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] result Deserialized dataset after the operation | |||
| /// \return Status The status code returned | |||
| static Status CreateBatchOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| static Status CreateMapOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| static Status CreateProjectOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| static Status CreateRenameOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| static Status CreateRepeatOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| static Status CreateShuffleOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| static Status CreateSkipOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| static Status CreateTransferOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| static Status CreateTakeOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, | |||
| std::shared_ptr<DatasetNode> *result); | |||
| /// \brief Helper functions for different samplers | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[out] sampler Deserialized sampler | |||
| /// \return Status The status code returned | |||
| static Status ConstructDistributedSampler(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler); | |||
| static Status ConstructPKSampler(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler); | |||
| static Status ConstructRandomSampler(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler); | |||
| static Status ConstructSequentialSampler(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler); | |||
| static Status ConstructSubsetRandomSampler(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler); | |||
| static Status ConstructWeightedRandomSampler(nlohmann::json json_obj, int64_t num_samples, | |||
| std::shared_ptr<SamplerObj> *sampler); | |||
| /// \brief Helper functions to construct children samplers | |||
| /// \param[in] json_obj The JSON object to be deserialized | |||
| /// \param[in] parent_sampler given parent sampler | |||
| /// \param[out] sampler sampler constructed - parent sampler with children samplers added | |||
| /// \return Status The status code returned | |||
| static Status ChildSamplerFromJson(nlohmann::json json_obj, std::shared_ptr<SamplerObj> parent_sampler, | |||
| std::shared_ptr<SamplerObj> *sampler); | |||
| /// \brief Helper functions for vision operations, which requires tensor operations as input | |||
| /// \param[in] op_params operation parameters for the operation | |||
| /// \param[out] operation deserialized operation | |||
| /// \return Status The status code returned | |||
| static Status BoundingBoxAugmentFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation); | |||
| static Status RandomSelectSubpolicyFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation); | |||
| static Status UniformAugFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation); | |||
| /// \brief Helper function to map the function pointers | |||
| /// \return map of key to function pointer | |||
| static std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> *operation)> | |||
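Design note: ConstructSampler and ConstructTensorOps move from the protected block into the public interface above so that the per-operation from_json implementations in the kernel IR files (next hunks) can reuse them. A hedged illustration of calling the helper, with a placeholder op name and parameters; the exact JSON keys are assumptions mirrored from the dispatch sketch earlier:

    // Inside some Status-returning function; values are placeholders for illustration.
    nlohmann::json flip;
    flip["tensor_op_name"] = "RandomVerticalFlip";
    flip["tensor_op_params"]["prob"] = 0.5;
    nlohmann::json ops_json = nlohmann::json::array();
    ops_json.push_back(flip);
    std::vector<std::shared_ptr<TensorOperation>> result;
    RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(ops_json, &result));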
| @@ -18,6 +18,7 @@ | |||
| #include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #include "minddata/dataset/kernels/image/bounding_box_augment_op.h" | |||
| #endif | |||
| @@ -56,6 +57,20 @@ Status BoundingBoxAugmentOperation::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status BoundingBoxAugmentOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("transform") != op_params.end(), "Failed to find transform"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("ratio") != op_params.end(), "Failed to find ratio"); | |||
| std::vector<std::shared_ptr<TensorOperation>> transforms; | |||
| std::vector<nlohmann::json> json_operations = {}; | |||
| json_operations.push_back(op_params["transform"]); | |||
| RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(json_operations, &transforms)); | |||
| float ratio = op_params["ratio"]; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(transforms.size() == 1, | |||
| "Expected the transform parameter to contain exactly one operation, but got: " + std::to_string(transforms.size())); | |||
| *operation = std::make_shared<vision::BoundingBoxAugmentOperation>(transforms[0], ratio); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace vision | |||
| } // namespace dataset | |||
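A hedged example of the op_params shape this from_json walks, mirroring the to_json above: "transform" holds a single serialized operation (hence the one-element vector check), and "ratio" is a float. The nested op and its parameters are placeholders:

    nlohmann::json op_params;
    op_params["transform"]["tensor_op_name"] = "RandomVerticalFlip";  // single nested op, not an array
    op_params["transform"]["tensor_op_params"]["prob"] = 0.5;         // placeholder parameters
    op_params["ratio"] = 1.0;
    std::shared_ptr<TensorOperation> op;
    Status rc = vision::BoundingBoxAugmentOperation::from_json(op_params, &op);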
| @@ -49,6 +49,8 @@ class BoundingBoxAugmentOperation : public TensorOperation { | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation); | |||
| private: | |||
| std::shared_ptr<TensorOperation> transform_; | |||
| float ratio_; | |||
| @@ -18,6 +18,7 @@ | |||
| #include "minddata/dataset/kernels/ir/vision/random_select_subpolicy_ir.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #include "minddata/dataset/kernels/image/random_select_subpolicy_op.h" | |||
| #endif | |||
| @@ -100,6 +101,33 @@ Status RandomSelectSubpolicyOperation::to_json(nlohmann::json *out_json) { | |||
| (*out_json)["policy"] = policy_tensor_ops; | |||
| return Status::OK(); | |||
| } | |||
| Status RandomSelectSubpolicyOperation::from_json(nlohmann::json op_params, | |||
| std::shared_ptr<TensorOperation> *operation) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("policy") != op_params.end(), "Failed to find policy"); | |||
| nlohmann::json policy_json = op_params["policy"]; | |||
| std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy; | |||
| for (nlohmann::json item : policy_json) { | |||
| // Declared per subpolicy so pairs from previous subpolicies do not accumulate. | |||
| std::vector<std::pair<std::shared_ptr<TensorOperation>, double>> policy_items; | |||
| for (nlohmann::json item_pair : item) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(item_pair.find("prob") != item_pair.end(), "Failed to find prob"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(item_pair.find("tensor_op") != item_pair.end(), "Failed to find tensor_op"); | |||
| std::vector<std::shared_ptr<TensorOperation>> operations; | |||
| std::pair<std::shared_ptr<TensorOperation>, double> policy_pair; | |||
| nlohmann::json tensor_op_json; | |||
| double prob = item_pair["prob"]; | |||
| tensor_op_json.push_back(item_pair["tensor_op"]); | |||
| RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(tensor_op_json, &operations)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(operations.size() == 1, "There should be only 1 tensor operation"); | |||
| policy_pair = std::make_pair(operations[0], prob); | |||
| policy_items.push_back(policy_pair); | |||
| } | |||
| policy.push_back(policy_items); | |||
| } | |||
| *operation = std::make_shared<vision::RandomSelectSubpolicyOperation>(policy); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace vision | |||
| } // namespace dataset | |||
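A hedged sketch of the "policy" layout this from_json iterates: an array of subpolicies, each itself an array of {"prob", "tensor_op"} pairs, matching the to_json above. The inner op is a placeholder:

    nlohmann::json pair;
    pair["prob"] = 0.5;
    pair["tensor_op"]["tensor_op_name"] = "RandomVerticalFlip";   // placeholder op
    pair["tensor_op"]["tensor_op_params"]["prob"] = 0.5;
    nlohmann::json op_params;
    op_params["policy"] = nlohmann::json::array({nlohmann::json::array({pair})});  // one subpolicy with one pair
    std::shared_ptr<TensorOperation> op;
    Status rc = vision::RandomSelectSubpolicyOperation::from_json(op_params, &op);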
| @@ -50,6 +50,8 @@ class RandomSelectSubpolicyOperation : public TensorOperation { | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation); | |||
| private: | |||
| std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy_; | |||
| }; | |||
| @@ -18,6 +18,7 @@ | |||
| #include "minddata/dataset/kernels/ir/vision/uniform_aug_ir.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/serdes.h" | |||
| #include "minddata/dataset/kernels/image/uniform_aug_op.h" | |||
| #endif | |||
| @@ -74,6 +75,16 @@ Status UniformAugOperation::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status UniformAugOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("transforms") != op_params.end(), "Failed to find transforms"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("num_ops") != op_params.end(), "Failed to find num_ops"); | |||
| std::vector<std::shared_ptr<TensorOperation>> transforms = {}; | |||
| RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(op_params["transforms"], &transforms)); | |||
| int32_t num_ops = op_params["num_ops"]; | |||
| *operation = std::make_shared<vision::UniformAugOperation>(transforms, num_ops); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace vision | |||
| } // namespace dataset | |||
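A hedged example of the parameters expected here: "transforms" is a JSON array of serialized operations handed to ConstructTensorOps, and "num_ops" selects how many of them are applied. Placeholder values only:

    nlohmann::json flip;
    flip["tensor_op_name"] = "RandomVerticalFlip";   // placeholder op
    flip["tensor_op_params"]["prob"] = 0.5;
    nlohmann::json op_params;
    op_params["transforms"] = nlohmann::json::array({flip});
    op_params["num_ops"] = 1;
    std::shared_ptr<TensorOperation> op;
    Status rc = vision::UniformAugOperation::from_json(op_params, &op);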
| @@ -49,6 +49,8 @@ class UniformAugOperation : public TensorOperation { | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation); | |||
| private: | |||
| std::vector<std::shared_ptr<TensorOperation>> transforms_; | |||
| int32_t num_ops_; | |||
| @@ -462,6 +462,7 @@ TEST_F(MindDataTestDeserialize, TestDeserializeFill) { | |||
| std::shared_ptr<TensorOperation> operation2 = std::make_shared<text::ToNumberOperation>("int32_t"); | |||
| std::vector<std::shared_ptr<TensorOperation>> ops = {operation1, operation2}; | |||
| ds = std::make_shared<MapNode>(ds, ops); | |||
| ds = std::make_shared<TransferNode>(ds, "queue", "type", 1, true, 10, true); | |||
| compare_dataset(ds); | |||
| } | |||
| @@ -482,3 +483,19 @@ TEST_F(MindDataTestDeserialize, TestDeserializeTensor) { | |||
| json_ss1 << json_obj1; | |||
| EXPECT_EQ(json_ss.str(), json_ss1.str()); | |||
| } | |||
| // Helper function to get the session id from SESSION_ID env variable | |||
| Status GetSessionFromEnv(session_id_type *session_id); | |||
| TEST_F(MindDataTestDeserialize, DISABLED_TestDeserializeCache) { | |||
| MS_LOG(INFO) << "Doing MindDataTestDeserialize-Cache."; | |||
| std::string data_dir = "./data/dataset/testCache"; | |||
| std::string usage = "all"; | |||
| session_id_type env_session; | |||
| ASSERT_TRUE(GetSessionFromEnv(&env_session).IsOk()); | |||
| std::shared_ptr<DatasetCache> some_cache = CreateDatasetCache(env_session, 0, false, "127.0.0.1", 50052, 1, 1); | |||
| std::shared_ptr<SamplerObj> sampler = std::make_shared<SequentialSamplerObj>(0, 10); | |||
| std::shared_ptr<DatasetNode> ds = std::make_shared<Cifar10Node>(data_dir, usage, sampler, some_cache); | |||
| compare_dataset(ds); | |||
| } | |||
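Editorial note on the new test: it is prefixed DISABLED_ presumably because it needs a running cache server and a valid SESSION_ID environment variable. For context, a hedged sketch of the round trip that compare_dataset (defined earlier in this test file, not shown in the diff) is assumed to perform; the SaveToJSON signature is an assumption:

    nlohmann::json out_json;
    ASSERT_TRUE(Serdes::SaveToJSON(ds, "", &out_json).IsOk());          // serialize the pipeline (signature assumed)
    std::shared_ptr<DatasetNode> rebuilt;
    ASSERT_TRUE(Serdes::ConstructPipeline(out_json, &rebuilt).IsOk());  // rebuild it from JSON
    nlohmann::json out_json2;
    ASSERT_TRUE(Serdes::SaveToJSON(rebuilt, "", &out_json2).IsOk());
    EXPECT_EQ(out_json, out_json2);                                     // both serializations should match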