| @@ -526,6 +526,11 @@ std::shared_ptr<TensorOp> CenterCropOperation::Build() { | |||
| return tensor_op; | |||
| } | |||
| Status CenterCropOperation::to_json(nlohmann::json *out_json) { | |||
| (*out_json)["size"] = size_; | |||
| return Status::OK(); | |||
| } | |||
| // CropOperation. | |||
| CropOperation::CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size) | |||
| : coordinates_(coordinates), size_(size) {} | |||
| @@ -638,6 +643,11 @@ Status DecodeOperation::ValidateParams() { return Status::OK(); } | |||
| std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); } | |||
| Status DecodeOperation::to_json(nlohmann::json *out_json) { | |||
| (*out_json)["rgb"] = rgb_; | |||
| return Status::OK(); | |||
| } | |||
| // EqualizeOperation | |||
| Status EqualizeOperation::ValidateParams() { return Status::OK(); } | |||
| @@ -801,6 +811,14 @@ std::shared_ptr<TensorOp> NormalizeOperation::Build() { | |||
| return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]); | |||
| } | |||
| Status NormalizeOperation::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["mean"] = mean_; | |||
| args["std"] = std_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // NormalizePadOperation | |||
| NormalizePadOperation::NormalizePadOperation(const std::vector<float> &mean, const std::vector<float> &std, | |||
| @@ -893,6 +911,15 @@ std::shared_ptr<TensorOp> PadOperation::Build() { | |||
| return tensor_op; | |||
| } | |||
| Status PadOperation::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["padding"] = padding_; | |||
| args["fill_value"] = fill_value_; | |||
| args["padding_mode"] = padding_mode_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // RandomAffineOperation | |||
| RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> &degrees, | |||
| const std::vector<float_t> &translate_range, | |||
| @@ -1188,6 +1215,16 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() { | |||
| return tensor_op; | |||
| } | |||
| Status RandomColorAdjustOperation::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["brightness"] = brightness_; | |||
| args["contrast"] = contrast_; | |||
| args["saturation"] = saturation_; | |||
| args["hue"] = hue_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // RandomCropOperation | |||
| RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed, | |||
| std::vector<uint8_t> fill_value, BorderType padding_mode) | |||
| @@ -1261,6 +1298,17 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() { | |||
| return tensor_op; | |||
| } | |||
| Status RandomCropOperation::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["size"] = size_; | |||
| args["padding"] = padding_; | |||
| args["pad_if_needed"] = pad_if_needed_; | |||
| args["fill_value"] = fill_value_; | |||
| args["padding_mode"] = padding_mode_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // RandomCropDecodeResizeOperation | |||
| RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale, | |||
| std::vector<float> ratio, | |||
| @@ -1735,6 +1783,17 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() { | |||
| return tensor_op; | |||
| } | |||
| Status RandomRotationOperation::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["degrees"] = degrees_; | |||
| args["interpolation_mode"] = interpolation_mode_; | |||
| args["expand"] = expand_; | |||
| args["center"] = center_; | |||
| args["fill_value"] = fill_value_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // RandomSelectSubpolicyOperation. | |||
| RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation( | |||
| std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy) | |||
| @@ -1889,6 +1948,14 @@ std::shared_ptr<TensorOp> RescaleOperation::Build() { | |||
| return tensor_op; | |||
| } | |||
| Status RescaleOperation::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["rescale"] = rescale_; | |||
| args["shift"] = shift_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| // ResizeOperation | |||
| ResizeOperation::ResizeOperation(std::vector<int32_t> size, InterpolationMode interpolation) | |||
| @@ -1920,6 +1987,14 @@ std::shared_ptr<TensorOp> ResizeOperation::Build() { | |||
| return std::make_shared<ResizeOp>(height, width, interpolation_); | |||
| } | |||
| Status ResizeOperation::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["size"] = size_; | |||
| args["interpolation"] = interpolation_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // RotateOperation | |||
| RotateOperation::RotateOperation() { rotate_op = std::make_shared<RotateOp>(0); } | |||
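The to_json overrides above serialize each vision IR operation's constructor arguments under parameter-name keys; MapNode::to_json (further down) then adds "tensor_op_name" and collects these objects into an "operations" list. A minimal Python sketch of the argument objects a few of these methods would produce; the values are illustrative, and enum-valued fields such as padding_mode and interpolation_mode are assumed here to serialize as plain integers:

# Illustrative argument objects matching the to_json methods above (values are made up).
pad_args = {"padding": [20, 20, 20, 20], "fill_value": [20, 20, 20], "padding_mode": 0}
random_color_adjust_args = {"brightness": [0.5, 0.5], "contrast": [1.0, 1.0],
                            "saturation": [1.0, 1.0], "hue": [0.0, 0.0]}
random_rotation_args = {"degrees": [0.0, 90.0], "interpolation_mode": 0, "expand": True,
                        "center": [50.0, 50.0], "fill_value": [150, 150, 150]}
rescale_args = {"rescale": 1.0 / 255, "shift": 0.0}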
| @@ -44,18 +44,5 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data | |||
| return Status::OK(); | |||
| } | |||
| Status DatasetCacheImpl::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["session_id"] = session_id_; | |||
| args["cache_memory_size"] = cache_mem_sz_; | |||
| args["spill"] = spill_; | |||
| if (hostname_) args["hostname"] = hostname_.value(); | |||
| if (port_) args["port"] = port_.value(); | |||
| if (num_connections_) args["num_connections"] = num_connections_.value(); | |||
| if (prefetch_sz_) args["prefetch_size"] = prefetch_sz_.value(); | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -60,8 +60,6 @@ class DatasetCacheImpl : public DatasetCache { | |||
| ~DatasetCacheImpl() = default; | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::shared_ptr<CacheClient> cache_client_; | |||
| session_id_type session_id_; | |||
| @@ -36,5 +36,15 @@ Status PreBuiltDatasetCache::CreateCacheOp(int32_t num_workers, std::shared_ptr< | |||
| return Status::OK(); | |||
| } | |||
| Status PreBuiltDatasetCache::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["session_id"] = cache_client_->session_id(); | |||
| args["cache_memory_size"] = cache_client_->GetCacheMemSz(); | |||
| args["spill"] = cache_client_->isSpill(); | |||
| args["num_connections"] = cache_client_->GetNumConnections(); | |||
| args["prefetch_size"] = cache_client_->GetPrefetchSize(); | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -42,6 +42,8 @@ class PreBuiltDatasetCache : public DatasetCache { | |||
| Status ValidateParams() override { return Status::OK(); } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::shared_ptr<CacheClient> cache_client_; | |||
| }; | |||
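PreBuiltDatasetCache::to_json reads its attributes back from the attached CacheClient, and the source-node to_json methods below nest that object under a "cache" key. A sketch of the resulting fragment, with made-up values:

# Illustrative "cache" fragment produced by PreBuiltDatasetCache::to_json (values are hypothetical).
cache_args = {
    "session_id": 1456416665,
    "cache_memory_size": 0,     # assumed to mean "no explicit memory cap" here
    "spill": False,
    "num_connections": 12,
    "prefetch_size": 20,
}
node_args = {"dataset_dir": "/path/to/data", "cache": cache_args}  # how a source node embeds it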
| @@ -63,6 +63,15 @@ class BucketBatchByLengthNode : public DatasetNode { | |||
| bool IsSizeDefined() override { return false; }; | |||
| /// \brief Getter functions | |||
| const std::vector<std::string> &ColumnNames() const { return column_names_; } | |||
| const std::vector<int32_t> &BucketBoundaries() const { return bucket_boundaries_; } | |||
| const std::vector<int32_t> &BucketBatchSizes() const { return bucket_batch_sizes_; } | |||
| const std::shared_ptr<TensorOp> &ElementLengthFunction() const { return element_length_function_; } | |||
| const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &PadInfo() const { return pad_info_; } | |||
| bool PadToBucketBoundary() const { return pad_to_bucket_boundary_; } | |||
| bool DropRemainder() const { return drop_remainder_; } | |||
| private: | |||
| std::vector<std::string> column_names_; | |||
| std::vector<int32_t> bucket_boundaries_; | |||
| @@ -72,6 +72,14 @@ class BuildSentenceVocabNode : public DatasetNode { | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | |||
| /// \brief Getter functions | |||
| const std::shared_ptr<SentencePieceVocab> &GetVocab() const { return vocab_; } | |||
| const std::vector<std::string> &ColNames() const { return col_names_; } | |||
| int32_t VocabSize() const { return vocab_size_; } | |||
| float CharacterCoverage() const { return character_coverage_; } | |||
| SentencePieceModel ModelType() const { return model_type_; } | |||
| const std::unordered_map<std::string, std::string> &Params() const { return params_; } | |||
| private: | |||
| std::shared_ptr<SentencePieceVocab> vocab_; | |||
| std::vector<std::string> col_names_; | |||
| @@ -70,6 +70,14 @@ class BuildVocabNode : public DatasetNode { | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | |||
| /// \brief Getter functions | |||
| const std::shared_ptr<Vocab> &GetVocab() const { return vocab_; } | |||
| const std::vector<std::string> &Columns() const { return columns_; } | |||
| const std::pair<int64_t, int64_t> &FreqRange() const { return freq_range_; } | |||
| int64_t TopK() const { return top_k_; } | |||
| const std::vector<std::string> &SpecialTokens() const { return special_tokens_; } | |||
| bool SpecialFirst() const { return special_first_; } | |||
| private: | |||
| std::shared_ptr<Vocab> vocab_; | |||
| std::vector<std::string> columns_; | |||
| @@ -61,6 +61,10 @@ class ConcatNode : public DatasetNode { | |||
| bool IsSizeDefined() override { return false; } | |||
| /// \brief Getter functions | |||
| const std::vector<std::pair<int, int>> &ChildrenFlagAndNums() const { return children_flag_and_nums_; } | |||
| const std::vector<std::pair<int, int>> &ChildrenStartEndIndex() const { return children_start_end_index_; } | |||
| private: | |||
| std::shared_ptr<SamplerObj> sampler_; | |||
| std::vector<std::pair<int, int>> children_flag_and_nums_; | |||
| @@ -73,5 +73,13 @@ Status FilterNode::AcceptAfter(IRNodePass *const p, bool *const modified) { | |||
| return p->VisitAfter(shared_from_base<FilterNode>(), modified); | |||
| } | |||
| Status FilterNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["input_columns"] = input_columns_; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["predicate"] = "pyfunc"; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -70,6 +70,15 @@ class FilterNode : public DatasetNode { | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | |||
| /// \brief Getter functions | |||
| const std::shared_ptr<TensorOp> &Predicate() const { return predicate_; } | |||
| const std::vector<std::string> &InputColumns() const { return input_columns_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::shared_ptr<TensorOp> predicate_; | |||
| std::vector<std::string> input_columns_; | |||
| @@ -138,8 +138,8 @@ Status MapNode::to_json(nlohmann::json *out_json) { | |||
| std::vector<nlohmann::json> ops; | |||
| std::vector<int32_t> cbs; | |||
| nlohmann::json op_args; | |||
| for (auto op : operations_) { | |||
| nlohmann::json op_args; | |||
| RETURN_IF_NOT_OK(op->to_json(&op_args)); | |||
| op_args["tensor_op_name"] = op->Name(); | |||
| ops.push_back(op_args); | |||
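Moving the op_args declaration inside the loop matters because some of the tensor-op to_json methods above (for example CenterCropOperation and DecodeOperation) write into the passed-in object rather than replacing it, so a single op_args reused across iterations would carry one op's keys into the next op's entry. A small Python sketch of the difference:

specs = [("Decode", {"rgb": True}), ("Rescale", {"rescale": 1.0 / 255, "shift": 0.0})]

# Shared object, as with the removed declaration: Decode's "rgb" key leaks into Rescale's entry.
shared, leaky_ops = {}, []
for name, fields in specs:
    shared.update(fields)
    shared["tensor_op_name"] = name
    leaky_ops.append(dict(shared))

# Fresh object per iteration, as in the fixed loop: each entry holds only its own keys.
clean_ops = []
for name, fields in specs:
    op_args = {"tensor_op_name": name, **fields}
    clean_ops.append(op_args)

assert "rgb" in leaky_ops[1] and "rgb" not in clean_ops[1]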
| @@ -57,5 +57,11 @@ Status ProjectNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op | |||
| return Status::OK(); | |||
| } | |||
| Status ProjectNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["columns"] = columns_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -55,6 +55,14 @@ class ProjectNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Getter functions | |||
| const std::vector<std::string> &Columns() const { return columns_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::vector<std::string> columns_; | |||
| }; | |||
| @@ -29,7 +29,7 @@ namespace dataset { | |||
| class RootNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| RootNode() : DatasetNode() {} | |||
| RootNode() : DatasetNode(), num_epochs_(0) {} | |||
| /// \brief Constructor | |||
| explicit RootNode(std::shared_ptr<DatasetNode> child); | |||
| @@ -83,5 +83,12 @@ Status SkipNode::AcceptAfter(IRNodePass *const p, bool *const modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<SkipNode>(), modified); | |||
| } | |||
| Status SkipNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["count"] = skip_count_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -80,6 +80,14 @@ class SkipNode : public DatasetNode { | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | |||
| /// \brief Getter functions | |||
| int32_t SkipCount() const { return skip_count_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| int32_t skip_count_; | |||
| }; | |||
| @@ -61,6 +61,12 @@ class AlbumNode : public MappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Getter functions | |||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||
| const std::string &SchemaPath() const { return schema_path_; } | |||
| const std::vector<std::string> &ColumnNames() const { return column_names_; } | |||
| bool Decode() const { return decode_; } | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string schema_path_; | |||
| @@ -144,5 +144,22 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size | |||
| return Status::OK(); | |||
| } | |||
| Status CelebANode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args, sampler_args; | |||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||
| args["sampler"] = sampler_args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_dir_; | |||
| args["decode"] = decode_; | |||
| args["extensions"] = extensions_; | |||
| args["usage"] = usage_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -71,6 +71,17 @@ class CelebANode : public MappableSourceNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions | |||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||
| const std::string &Usage() const { return usage_; } | |||
| bool Decode() const { return decode_; } | |||
| const std::set<std::string> &Extensions() const { return extensions_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string usage_; | |||
| @@ -95,5 +95,20 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si | |||
| return Status::OK(); | |||
| } | |||
| Status Cifar100Node::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args, sampler_args; | |||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||
| args["sampler"] = sampler_args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_dir_; | |||
| args["usage"] = usage_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -69,6 +69,15 @@ class Cifar100Node : public MappableSourceNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions | |||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||
| const std::string &Usage() const { return usage_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string usage_; | |||
| @@ -93,5 +93,20 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz | |||
| return Status::OK(); | |||
| } | |||
| Status Cifar10Node::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args, sampler_args; | |||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||
| args["sampler"] = sampler_args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_dir_; | |||
| args["usage"] = usage_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
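The mappable source nodes follow one pattern: serialize the sampler, the worker count, the node's own constructor arguments, and the optional cache. A sketch of the object Cifar10Node::to_json would produce; values are illustrative, and the sampler sub-object's exact fields come from sampler_->to_json and are assumed here:

cifar10_args = {
    "sampler": {"sampler_name": "SequentialSampler", "num_samples": 0},  # hypothetical sampler fields
    "num_parallel_workers": 1,
    "dataset_dir": "../data/dataset/testCifar10Data",
    "usage": "all",
    # a "cache" object is added only when cache_ != nullptr
}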
| @@ -69,6 +69,15 @@ class Cifar10Node : public MappableSourceNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions | |||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||
| const std::string &Usage() const { return usage_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string usage_; | |||
| @@ -252,5 +252,23 @@ Status CLUENode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g | |||
| return Status::OK(); | |||
| } | |||
| Status CLUENode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_files_; | |||
| args["task"] = task_; | |||
| args["usage"] = usage_; | |||
| args["num_samples"] = num_samples_; | |||
| args["shuffle"] = shuffle_; | |||
| args["num_shards"] = num_shards_; | |||
| args["shard_id"] = shard_id_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -71,6 +71,20 @@ class CLUENode : public NonMappableSourceNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions | |||
| const std::vector<std::string> &DatasetFiles() const { return dataset_files_; } | |||
| const std::string &Task() const { return task_; } | |||
| const std::string &Usage() const { return usage_; } | |||
| int64_t NumSamples() const { return num_samples_; } | |||
| ShuffleMode Shuffle() const { return shuffle_; } | |||
| int32_t NumShards() const { return num_shards_; } | |||
| int32_t ShardId() const { return shard_id_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| /// \brief Split string based on a character delimiter | |||
| /// \return A string vector | |||
| @@ -151,5 +151,22 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g | |||
| return Status::OK(); | |||
| } | |||
| Status CocoNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args, sampler_args; | |||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||
| args["sampler"] = sampler_args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_dir_; | |||
| args["annotation_file"] = annotation_file_; | |||
| args["task"] = task_; | |||
| args["decode"] = decode_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -69,6 +69,17 @@ class CocoNode : public MappableSourceNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions | |||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||
| const std::string &AnnotationFile() const { return annotation_file_; } | |||
| const std::string &Task() const { return task_; } | |||
| bool Decode() const { return decode_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string annotation_file_; | |||
| @@ -170,5 +170,23 @@ Status CSVNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge | |||
| return Status::OK(); | |||
| } | |||
| Status CSVNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_files"] = dataset_files_; | |||
| args["field_delim"] = std::string(1, field_delim_); | |||
| args["column_names"] = column_names_; | |||
| args["num_samples"] = num_samples_; | |||
| args["shuffle"] = shuffle_; | |||
| args["num_shards"] = num_shards_; | |||
| args["shard_id"] = shard_id_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -92,6 +92,21 @@ class CSVNode : public NonMappableSourceNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions | |||
| const std::vector<std::string> &DatasetFiles() const { return dataset_files_; } | |||
| char FieldDelim() const { return field_delim_; } | |||
| const std::vector<std::shared_ptr<CsvBase>> &ColumnDefaults() const { return column_defaults_; } | |||
| const std::vector<std::string> &ColumnNames() const { return column_names_; } | |||
| int64_t NumSamples() const { return num_samples_; } | |||
| ShuffleMode Shuffle() const { return shuffle_; } | |||
| int32_t NumShards() const { return num_shards_; } | |||
| int32_t ShardId() const { return shard_id_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::vector<std::string> dataset_files_; | |||
| char field_delim_; | |||
| @@ -83,6 +83,12 @@ class GeneratorNode : public MappableSourceNode { | |||
| return Status::OK(); | |||
| } | |||
| /// \brief Getter functions | |||
| const py::function &GeneratorFunction() const { return generator_function_; } | |||
| const std::vector<std::string> &ColumnNames() const { return column_names_; } | |||
| const std::vector<DataType> &ColumnTypes() const { return column_types_; } | |||
| const std::shared_ptr<SchemaObj> &Schema() const { return schema_; } | |||
| private: | |||
| py::function generator_function_; | |||
| std::vector<std::string> column_names_; | |||
| @@ -124,5 +124,23 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si | |||
| return Status::OK(); | |||
| } | |||
| Status ManifestNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args, sampler_args; | |||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||
| args["sampler"] = sampler_args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_file"] = dataset_file_; | |||
| args["usage"] = usage_; | |||
| args["class_indexing"] = class_index_; | |||
| args["decode"] = decode_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -70,6 +70,17 @@ class ManifestNode : public MappableSourceNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions | |||
| const std::string &DatasetFile() const { return dataset_file_; } | |||
| const std::string &Usage() const { return usage_; } | |||
| bool Decode() const { return decode_; } | |||
| const std::map<std::string, int32_t> &ClassIndex() const { return class_index_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::string dataset_file_; | |||
| std::string usage_; | |||
| @@ -91,6 +91,14 @@ class RandomNode : public NonMappableSourceNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions | |||
| int32_t TotalRows() const { return total_rows_; } | |||
| const std::string &SchemaPath() const { return schema_path_; } | |||
| const std::shared_ptr<SchemaObj> &GetSchema() const { return schema_; } | |||
| const std::vector<std::string> &ColumnsList() const { return columns_list_; } | |||
| const std::mt19937 &RandGen() const { return rand_gen_; } | |||
| const std::unique_ptr<DataSchema> &GetDataSchema() const { return data_schema_; } | |||
| private: | |||
| /// \brief A quick inline for producing a random number between (and including) min/max | |||
| /// \param[in] min minimum number that can be generated. | |||
| @@ -136,5 +136,21 @@ Status TextFileNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si | |||
| return Status::OK(); | |||
| } | |||
| Status TextFileNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_files"] = dataset_files_; | |||
| args["num_samples"] = num_samples_; | |||
| args["shuffle"] = shuffle_; | |||
| args["num_shards"] = num_shards_; | |||
| args["shard_id"] = shard_id_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -71,6 +71,18 @@ class TextFileNode : public NonMappableSourceNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions | |||
| const std::vector<std::string> &DatasetFiles() const { return dataset_files_; } | |||
| int32_t NumSamples() const { return num_samples_; } | |||
| int32_t NumShards() const { return num_shards_; } | |||
| int32_t ShardId() const { return shard_id_; } | |||
| ShuffleMode Shuffle() const { return shuffle_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::vector<std::string> dataset_files_; | |||
| int32_t num_samples_; | |||
| @@ -140,5 +140,23 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge | |||
| return Status::OK(); | |||
| } | |||
| Status VOCNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args, sampler_args; | |||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||
| args["sampler"] = sampler_args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_dir_; | |||
| args["task"] = task_; | |||
| args["usage"] = usage_; | |||
| args["class_indexing"] = class_index_; | |||
| args["decode"] = decode_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
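VOCNode (like ManifestNode above) also serializes its class_indexing map, which round-trips directly into the class_indexing argument of VOCDataset in the Python deserializer. An illustrative fragment, using the paths from the VOC test further below:

voc_args = {
    "sampler": {"sampler_name": "SequentialSampler"},  # hypothetical sampler fields
    "num_parallel_workers": 1,
    "dataset_dir": "../data/dataset/testVOC2012",
    "task": "Detection",
    "usage": "train",
    "class_indexing": {},   # stays empty unless an explicit class-to-label mapping was supplied
    "decode": True,
}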
| @@ -71,6 +71,18 @@ class VOCNode : public MappableSourceNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions | |||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||
| const std::string &Task() const { return task_; } | |||
| const std::string &Usage() const { return usage_; } | |||
| const std::map<std::string, int32_t> &ClassIndex() const { return class_index_; } | |||
| bool Decode() const { return decode_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| const std::string kColumnImage = "image"; | |||
| const std::string kColumnTarget = "target"; | |||
| @@ -57,6 +57,10 @@ class SyncWaitNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Getter functions | |||
| const std::string &ConditionName() const { return condition_name_; } | |||
| const py::function &Callback() const { return callback_; } | |||
| private: | |||
| std::string condition_name_; | |||
| py::function callback_; | |||
| @@ -81,5 +81,12 @@ Status TakeNode::AcceptAfter(IRNodePass *const p, bool *const modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<TakeNode>(), modified); | |||
| } | |||
| Status TakeNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["count"] = take_count_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -80,6 +80,14 @@ class TakeNode : public DatasetNode { | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | |||
| /// \brief Getter functions | |||
| int32_t TakeCount() const { return take_count_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| int32_t take_count_; | |||
| }; | |||
| @@ -116,5 +116,14 @@ Status TransferNode::AcceptAfter(IRNodePass *const p, bool *const modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<TransferNode>(), modified); | |||
| } | |||
| Status TransferNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["send_epoch_end"] = send_epoch_end_; | |||
| args["total_batch"] = total_batch_; | |||
| args["create_data_info_queue"] = create_data_info_queue_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -70,6 +70,20 @@ class TransferNode : public DatasetNode { | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | |||
| /// \brief Getter functions | |||
| const std::string &QueueName() const { return queue_name_; } | |||
| int32_t DeviceId() const { return device_id_; } | |||
| const std::string &DeviceType() const { return device_type_; } | |||
| int32_t PrefetchSize() const { return prefetch_size_; } | |||
| bool SendEpochEnd() const { return send_epoch_end_; } | |||
| int32_t TotalBatch() const { return total_batch_; } | |||
| bool CreateDataInfoQueue() const { return create_data_info_queue_; } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::string queue_name_; | |||
| int32_t device_id_; | |||
| @@ -668,6 +668,8 @@ class PadOperation : public TensorOperation { | |||
| std::string Name() const override { return kPadOperation; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::vector<int32_t> padding_; | |||
| std::vector<uint8_t> fill_value_; | |||
| @@ -729,6 +731,8 @@ class RandomColorAdjustOperation : public TensorOperation { | |||
| std::string Name() const override { return kRandomColorAdjustOperation; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::vector<float> brightness_; | |||
| std::vector<float> contrast_; | |||
| @@ -750,6 +754,8 @@ class RandomCropOperation : public TensorOperation { | |||
| std::string Name() const override { return kRandomCropOperation; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::vector<int32_t> size_; | |||
| std::vector<int32_t> padding_; | |||
| @@ -936,6 +942,8 @@ class RandomRotationOperation : public TensorOperation { | |||
| std::string Name() const override { return kRandomRotationOperation; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::vector<float> degrees_; | |||
| InterpolationMode interpolation_mode_; | |||
| @@ -1037,6 +1045,8 @@ class RescaleOperation : public TensorOperation { | |||
| std::string Name() const override { return kRescaleOperation; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| float rescale_; | |||
| float shift_; | |||
| @@ -105,6 +105,8 @@ class CenterCropOperation : public TensorOperation { | |||
| std::string Name() const override { return kCenterCropOperation; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::vector<int32_t> size_; | |||
| }; | |||
| @@ -137,6 +139,8 @@ class DecodeOperation : public TensorOperation { | |||
| std::string Name() const override { return kDecodeOperation; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| bool rgb_; | |||
| }; | |||
| @@ -153,6 +157,8 @@ class NormalizeOperation : public TensorOperation { | |||
| std::string Name() const override { return kNormalizeOperation; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::vector<float> mean_; | |||
| std::vector<float> std_; | |||
| @@ -171,6 +177,8 @@ class ResizeOperation : public TensorOperation { | |||
| std::string Name() const override { return kResizeOperation; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::vector<int32_t> size_; | |||
| InterpolationMode interpolation_; | |||
| @@ -33,5 +33,12 @@ Status TypeCastOp::OutputType(const std::vector<DataType> &inputs, std::vector<D | |||
| outputs[0] = type_; | |||
| return Status::OK(); | |||
| } | |||
| Status TypeCastOp::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["data_type"] = type_.ToString(); | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -43,6 +43,8 @@ class TypeCastOp : public TensorOp { | |||
| std::string Name() const override { return kTypeCastOp; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| DataType type_; | |||
| }; | |||
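TypeCastOp serializes its target type as the string returned by DataType::ToString(); on the Python side, construct_tensor_ops (below) converts it back with to_mstype. An illustrative entry, assuming Name() returns "TypeCastOp" and ToString() yields "int32":

# Hypothetical entry for a TypeCast op inside a Map node's "operations" list.
type_cast_args = {"tensor_op_name": "TypeCastOp", "data_type": "int32"}
# The deserializer strips the trailing "Op", resolves transforms.c_transforms.TypeCast,
# and calls it with to_mstype("int32").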
| @@ -61,11 +61,5 @@ Status DecodeOp::OutputType(const std::vector<DataType> &inputs, std::vector<Dat | |||
| return Status::OK(); | |||
| } | |||
| Status DecodeOp::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["rgb"] = is_rgb_format_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -42,8 +42,6 @@ class DecodeOp : public TensorOp { | |||
| std::string Name() const override { return kDecodeOp; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| bool is_rgb_format_ = true; | |||
| }; | |||
| @@ -138,16 +138,5 @@ Status RandomCropOp::OutputShape(const std::vector<TensorShape> &inputs, std::ve | |||
| if (!outputs.empty()) return Status::OK(); | |||
| return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | |||
| } | |||
| Status RandomCropOp::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["size"] = std::vector<int32_t>{crop_height_, crop_width_}; | |||
| args["padding"] = std::vector<int32_t>{pad_top_, pad_bottom_, pad_left_, pad_right_}; | |||
| args["pad_if_needed"] = pad_if_needed_; | |||
| args["fill_value"] = std::tuple<uint8_t, uint8_t, uint8_t>{fill_r_, fill_g_, fill_b_}; | |||
| args["padding_mode"] = border_type_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -79,8 +79,6 @@ class RandomCropOp : public TensorOp { | |||
| std::string Name() const override { return kRandomCropOp; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| protected: | |||
| int32_t crop_height_ = 0; | |||
| int32_t crop_width_ = 0; | |||
| @@ -29,13 +29,5 @@ Status RescaleOp::OutputType(const std::vector<DataType> &inputs, std::vector<Da | |||
| outputs[0] = DataType(DataType::DE_FLOAT32); | |||
| return Status::OK(); | |||
| } | |||
| Status RescaleOp::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["rescale"] = rescale_; | |||
| args["shift"] = shift_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -41,8 +41,6 @@ class RescaleOp : public TensorOp { | |||
| std::string Name() const override { return kRescaleOp; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| float rescale_; | |||
| float shift_; | |||
| @@ -67,13 +67,5 @@ Status ResizeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector | |||
| if (!outputs.empty()) return Status::OK(); | |||
| return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | |||
| } | |||
| Status ResizeOp::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["size"] = std::vector<int32_t>{size1_, size2_}; | |||
| args["interpolation"] = interpolation_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -61,8 +61,6 @@ class ResizeOp : public TensorOp { | |||
| std::string Name() const override { return kResizeOp; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| protected: | |||
| int32_t size1_; | |||
| int32_t size2_; | |||
| @@ -167,19 +167,82 @@ def create_node(node): | |||
| pyobj = None | |||
| # Find a matching Dataset class and call the constructor with the corresponding args. | |||
| # When a new Dataset class is introduced, another if clause and parsing code needs to be added. | |||
| if dataset_op == 'ImageFolderDataset': | |||
| # Dataset Source Ops (in alphabetical order) | |||
| if dataset_op == 'CelebADataset': | |||
| sampler = construct_sampler(node.get('sampler')) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_dir'], node.get('num_parallel_workers'), node.get('shuffle'), node.get('usage'), | |||
| sampler, node.get('decode'), node.get('extensions'), num_samples, node.get('num_shards'), | |||
| node.get('shard_id')) | |||
| elif dataset_op == 'Cifar10Dataset': | |||
| sampler = construct_sampler(node.get('sampler')) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples, node.get('num_parallel_workers'), | |||
| node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) | |||
| elif dataset_op == 'Cifar100Dataset': | |||
| sampler = construct_sampler(node.get('sampler')) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples, node.get('num_parallel_workers'), | |||
| node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) | |||
| elif dataset_op == 'ClueDataset': | |||
| shuffle = to_shuffle_mode(node.get('shuffle')) | |||
| if shuffle is not None and isinstance(shuffle, str): | |||
| shuffle = de.Shuffle(shuffle) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_files'], node.get('task'), | |||
| node.get('usage'), num_samples, node.get('num_parallel_workers'), shuffle, | |||
| node.get('num_shards'), node.get('shard_id')) | |||
| elif dataset_op == 'CocoDataset': | |||
| sampler = construct_sampler(node.get('sampler')) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_dir'], node.get('annotation_file'), node.get('task'), num_samples, | |||
| node.get('num_parallel_workers'), node.get('shuffle'), node.get('decode'), sampler, | |||
| node.get('num_shards'), node.get('shard_id')) | |||
| elif dataset_op == 'CSVDataset': | |||
| shuffle = to_shuffle_mode(node.get('shuffle')) | |||
| if shuffle is not None and isinstance(shuffle, str): | |||
| shuffle = de.Shuffle(shuffle) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_files'], node.get('field_delim'), | |||
| node.get('column_defaults'), node.get('column_names'), num_samples, | |||
| node.get('num_parallel_workers'), shuffle, | |||
| node.get('num_shards'), node.get('shard_id')) | |||
| elif dataset_op == 'ImageFolderDataset': | |||
| sampler = construct_sampler(node.get('sampler')) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_dir'], num_samples, node.get('num_parallel_workers'), | |||
| node.get('shuffle'), sampler, node.get('extensions'), | |||
| node.get('class_indexing'), node.get('decode'), node.get('num_shards'), | |||
| node.get('shard_id'), node.get('cache')) | |||
| node.get('shard_id')) | |||
| elif dataset_op == 'ManifestDataset': | |||
| sampler = construct_sampler(node.get('sampler')) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_file'], node['usage'], num_samples, | |||
| node.get('num_parallel_workers'), node.get('shuffle'), sampler, | |||
| node.get('class_indexing'), node.get('decode'), node.get('num_shards'), | |||
| node.get('shard_id')) | |||
| elif dataset_op == 'MnistDataset': | |||
| sampler = construct_sampler(node.get('sampler')) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples, node.get('num_parallel_workers'), | |||
| node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'), node.get('cache')) | |||
| node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) | |||
| elif dataset_op == 'TextFileDataset': | |||
| shuffle = to_shuffle_mode(node.get('shuffle')) | |||
| if shuffle is not None and isinstance(shuffle, str): | |||
| shuffle = de.Shuffle(shuffle) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_files'], num_samples, | |||
| node.get('num_parallel_workers'), shuffle, | |||
| node.get('num_shards'), node.get('shard_id')) | |||
| elif dataset_op == 'TFRecordDataset': | |||
| shuffle = to_shuffle_mode(node.get('shuffle')) | |||
| @@ -188,30 +251,50 @@ def create_node(node): | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('columns_list'), | |||
| num_samples, node.get('num_parallel_workers'), | |||
| shuffle, node.get('num_shards'), node.get('shard_id'), node.get('cache')) | |||
| shuffle, node.get('num_shards'), node.get('shard_id')) | |||
| elif dataset_op == 'Repeat': | |||
| pyobj = de.Dataset().repeat(node.get('count')) | |||
| elif dataset_op == 'VOCDataset': | |||
| sampler = construct_sampler(node.get('sampler')) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_dir'], node.get('task'), node.get('usage'), node.get('class_indexing'), | |||
| num_samples, node.get('num_parallel_workers'), node.get('shuffle'), | |||
| node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id')) | |||
| # Dataset Ops (in alphabetical order) | |||
| elif dataset_op == 'Batch': | |||
| pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder')) | |||
| elif dataset_op == 'Map': | |||
| tensor_ops = construct_tensor_ops(node.get('operations')) | |||
| pyobj = de.Dataset().map(tensor_ops, node.get('input_columns'), node.get('output_columns'), | |||
| node.get('column_order'), node.get('num_parallel_workers'), | |||
| True, node.get('cache'), node.get('callbacks')) | |||
| True, node.get('callbacks')) | |||
| elif dataset_op == 'Project': | |||
| pyobj = de.Dataset().project(node['columns']) | |||
| elif dataset_op == 'Rename': | |||
| pyobj = de.Dataset().rename(node['input_columns'], node['output_columns']) | |||
| elif dataset_op == 'Repeat': | |||
| pyobj = de.Dataset().repeat(node.get('count')) | |||
| elif dataset_op == 'Shuffle': | |||
| pyobj = de.Dataset().shuffle(node.get('buffer_size')) | |||
| elif dataset_op == 'Batch': | |||
| pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder')) | |||
| elif dataset_op == 'Skip': | |||
| pyobj = de.Dataset().skip(node.get('count')) | |||
| elif dataset_op == 'Take': | |||
| pyobj = de.Dataset().take(node.get('count')) | |||
| elif dataset_op == 'Transfer': | |||
| pyobj = de.Dataset().to_device(node.get('send_epoch_end'), node.get('create_data_info_queue')) | |||
| elif dataset_op == 'Zip': | |||
| # Create a ZipDataset instance, giving a dummy input dataset that will be overridden in the caller. | |||
| pyobj = de.ZipDataset((de.Dataset(), de.Dataset())) | |||
| elif dataset_op == 'Rename': | |||
| pyobj = de.Dataset().rename(node['input_columns'], node['output_columns']) | |||
| else: | |||
| raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize().") | |||
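create_node() is the inverse of the C++ to_json methods: given one node's JSON dictionary it rebuilds the matching Python dataset object. End to end the feature is driven through ds.serialize / ds.deserialize, roughly as the tests below do; a minimal sketch with illustrative paths:

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision

data = ds.Cifar10Dataset("../data/dataset/testCifar10Data", num_samples=10, shuffle=False)
data = data.map(operations=[vision.Resize((224, 224))], input_columns="image")
data = data.batch(2, drop_remainder=True)

ds.serialize(data, "cifar10_pipeline.json")                       # writes the JSON built by the to_json methods
rebuilt = ds.deserialize(json_filepath="cifar10_pipeline.json")   # rebuilds the pipeline via create_node()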
| @@ -252,35 +335,59 @@ def construct_tensor_ops(operations): | |||
| """Instantiate tensor op object(s) based on the information from dictionary['operations']""" | |||
| result = [] | |||
| for op in operations: | |||
| op_name = op['tensor_op_name'][:-2] # to remove op from the back of the name | |||
| op_name = op['tensor_op_name'] | |||
| op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"] | |||
| op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"] | |||
| if op_name == "HwcToChw": op_name = "HWC2CHW" | |||
| if hasattr(op_module_vis, op_name): | |||
| op_class = getattr(op_module_vis, op_name) | |||
| elif hasattr(op_module_trans, op_name): | |||
| elif hasattr(op_module_trans, op_name[:-2]): | |||
| op_name = op_name[:-2] # to remove op from the back of the name | |||
| op_class = getattr(op_module_trans, op_name) | |||
| else: | |||
| raise RuntimeError(op_name + " is not yet supported by deserialize().") | |||
| if op_name == 'Decode': | |||
| # Transforms Ops (in alphabetical order) | |||
| if op_name == 'OneHot': | |||
| result.append(op_class(op['num_classes'])) | |||
| elif op_name == 'TypeCast': | |||
| result.append(op_class(to_mstype(op['data_type']))) | |||
| # Vision Ops (in alphabetical order) | |||
| elif op_name == 'CenterCrop': | |||
| result.append(op_class(op['size'])) | |||
| elif op_name == 'Decode': | |||
| result.append(op_class(op.get('rgb'))) | |||
| elif op_name == 'HWC2CHW': | |||
| result.append(op_class()) | |||
| elif op_name == 'Normalize': | |||
| result.append(op_class(op['mean'], op['std'])) | |||
| elif op_name == 'Pad': | |||
| result.append(op_class(op['padding'], tuple(op['fill_value']), Border(to_border_mode(op['padding_mode'])))) | |||
| elif op_name == 'RandomColorAdjust': | |||
| result.append(op_class(op.get('brightness'), op.get('contrast'), op.get('saturation'), | |||
| op.get('hue'))) | |||
| elif op_name == 'RandomCrop': | |||
| result.append(op_class(op['size'], op.get('padding'), op.get('pad_if_needed'), | |||
| tuple(op.get('fill_value')), Border(to_border_mode(op.get('padding_mode'))))) | |||
| elif op_name == 'Resize': | |||
| result.append(op_class(op['size'], Inter(to_interpolation_mode(op.get('interpolation'))))) | |||
| elif op_name == 'RandomRotation': | |||
| result.append(op_class(op['degrees'], to_interpolation_mode(op.get('interpolation_mode')), op.get('expand'), | |||
| tuple(op.get('center')), tuple(op.get('fill_value')))) | |||
| elif op_name == 'Rescale': | |||
| result.append(op_class(op['rescale'], op['shift'])) | |||
| elif op_name == 'HWC2CHW': | |||
| result.append(op_class()) | |||
| elif op_name == 'OneHot': | |||
| result.append(op_class(op['num_classes'])) | |||
| elif op_name == 'Resize': | |||
| result.append(op_class(op['size'], to_interpolation_mode(op.get('interpolation')))) | |||
| else: | |||
| raise ValueError("Tensor op name is unknown: {}.".format(op_name)) | |||
| @@ -19,11 +19,13 @@ import filecmp | |||
| import glob | |||
| import json | |||
| import os | |||
| import pytest | |||
| import numpy as np | |||
| from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME | |||
| from util import config_get_set_num_parallel_workers, config_get_set_seed | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.c_transforms as c | |||
| import mindspore.dataset.vision.c_transforms as vision | |||
| @@ -31,7 +33,7 @@ from mindspore import log as logger | |||
| from mindspore.dataset.vision import Inter | |||
| def skip_test_imagefolder(remove_json_files=True): | |||
| def test_serdes_imagefolder_dataset(remove_json_files=True): | |||
| """ | |||
| Test simulating a ResNet50-style dataset pipeline. | |||
| """ | |||
| @@ -100,7 +102,7 @@ def skip_test_imagefolder(remove_json_files=True): | |||
| delete_json_files() | |||
| def test_mnist_dataset(remove_json_files=True): | |||
| def test_serdes_mnist_dataset(remove_json_files=True): | |||
| """ | |||
| Test serdes on MNIST dataset pipeline. | |||
| """ | |||
| @@ -141,7 +143,7 @@ def test_mnist_dataset(remove_json_files=True): | |||
| delete_json_files() | |||
| def test_zip_dataset(remove_json_files=True): | |||
| def test_serdes_zip_dataset(remove_json_files=True): | |||
| """ | |||
| Test serdes on zip dataset pipeline. | |||
| """ | |||
| @@ -185,7 +187,7 @@ def test_zip_dataset(remove_json_files=True): | |||
| delete_json_files() | |||
| def skip_test_random_crop(): | |||
| def test_serdes_random_crop(): | |||
| """ | |||
| Test serdes on RandomCrop pipeline. | |||
| """ | |||
| @@ -225,6 +227,179 @@ def skip_test_random_crop(): | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_serdes_cifar10_dataset(remove_json_files=True): | |||
| """ | |||
| Test serdes on Cifar10 dataset pipeline. | |||
| """ | |||
| data_dir = "../data/dataset/testCifar10Data" | |||
| original_seed = config_get_set_seed(1) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| data1 = ds.Cifar10Dataset(data_dir, num_samples=10, shuffle=False) | |||
| data1 = data1.take(6) | |||
| trans = [ | |||
| vision.RandomCrop((32, 32), (4, 4, 4, 4)), | |||
| vision.Resize((224, 224)), | |||
| vision.Rescale(1.0 / 255.0, 0.0), | |||
| vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), | |||
| vision.HWC2CHW() | |||
| ] | |||
| type_cast_op = c.TypeCast(mstype.int32) | |||
| data1 = data1.map(operations=type_cast_op, input_columns="label") | |||
| data1 = data1.map(operations=trans, input_columns="image") | |||
| data1 = data1.batch(3, drop_remainder=True) | |||
| data1 = data1.repeat(1) | |||
| data2 = util_check_serialize_deserialize_file(data1, "cifar10_dataset_pipeline", remove_json_files) | |||
| num_samples = 0 | |||
| # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2) | |||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), | |||
| data2.create_dict_iterator(num_epochs=1, output_numpy=True)): | |||
| np.testing.assert_array_equal(item1['image'], item2['image']) | |||
| num_samples += 1 | |||
| assert num_samples == 2 | |||
| # Restore configuration num_parallel_workers | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_serdes_celeba_dataset(remove_json_files=True): | |||
| """ | |||
| Test serdes on CelebA dataset pipeline. | |||
| """ | |||
| DATA_DIR = "../data/dataset/testCelebAData/" | |||
| data1 = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0) | |||
| # define map operations | |||
| data1 = data1.repeat(2) | |||
| center_crop = vision.CenterCrop((80, 80)) | |||
| pad_op = vision.Pad(20, fill_value=(20, 20, 20)) | |||
| data1 = data1.map(operations=[center_crop, pad_op], input_columns=["image"], num_parallel_workers=8) | |||
| data2 = util_check_serialize_deserialize_file(data1, "celeba_dataset_pipeline", remove_json_files) | |||
| num_samples = 0 | |||
| # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2) | |||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), | |||
| data2.create_dict_iterator(num_epochs=1, output_numpy=True)): | |||
| np.testing.assert_array_equal(item1['image'], item2['image']) | |||
| num_samples += 1 | |||
| assert num_samples == 8 | |||
| def test_serdes_csv_dataset(remove_json_files=True): | |||
| """ | |||
| Test serdes on CSVDataset pipeline. | |||
| """ | |||
| DATA_DIR = "../data/dataset/testCSV/1.csv" | |||
| data1 = ds.CSVDataset( | |||
| DATA_DIR, | |||
| column_defaults=["1", "2", "3", "4"], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| columns = ["col1", "col4", "col2"] | |||
| data1 = data1.project(columns=columns) | |||
| data2 = util_check_serialize_deserialize_file(data1, "csv_dataset_pipeline", remove_json_files) | |||
| num_samples = 0 | |||
| # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2) | |||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), | |||
| data2.create_dict_iterator(num_epochs=1, output_numpy=True)): | |||
| np.testing.assert_array_equal(item1['col1'], item2['col1']) | |||
| np.testing.assert_array_equal(item1['col2'], item2['col2']) | |||
| np.testing.assert_array_equal(item1['col4'], item2['col4']) | |||
| num_samples += 1 | |||
| assert num_samples == 3 | |||
| def test_serdes_voc_dataset(remove_json_files=True): | |||
| """ | |||
| Test serdes on VOC dataset pipeline. | |||
| """ | |||
| data_dir = "../data/dataset/testVOC2012" | |||
| original_seed = config_get_set_seed(1) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # define map operations | |||
| random_color_adjust_op = vision.RandomColorAdjust(brightness=(0.5, 0.5)) | |||
| random_rotation_op = vision.RandomRotation((0, 90), expand=True, resample=Inter.BILINEAR, center=(50, 50), | |||
| fill_value=150) | |||
| data1 = ds.VOCDataset(data_dir, task="Detection", usage="train", decode=True) | |||
| data1 = data1.map(operations=random_color_adjust_op, input_columns=["image"]) | |||
| data1 = data1.map(operations=random_rotation_op, input_columns=["image"]) | |||
| data1 = data1.skip(2) | |||
| data2 = util_check_serialize_deserialize_file(data1, "voc_dataset_pipeline", remove_json_files) | |||
| num_samples = 0 | |||
| # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2) | |||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), | |||
| data2.create_dict_iterator(num_epochs=1, output_numpy=True)): | |||
| np.testing.assert_array_equal(item1['image'], item2['image']) | |||
| num_samples += 1 | |||
| assert num_samples == 7 | |||
| # Restore configuration num_parallel_workers | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_serdes_to_device(remove_json_files=True): | |||
| """ | |||
| Test serdes on a to_device (transfer) pipeline. | |||
| """ | |||
| data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False) | |||
| data1 = data1.to_device() | |||
| util_check_serialize_deserialize_file(data1, "transfer_dataset_pipeline", remove_json_files) | |||
| def test_serdes_exception(): | |||
| """ | |||
| Test exception case in serdes. | |||
| """ | |||
| data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False) | |||
| data1 = data1.filter(input_columns=["image", "label"], predicate=lambda data: data < 11, num_parallel_workers=4) | |||
| data1_json = ds.serialize(data1) | |||
| with pytest.raises(RuntimeError) as msg: | |||
| ds.deserialize(input_dict=data1_json) | |||
| assert "Filter is not yet supported by ds.engine.deserialize" in str(msg) | |||
| def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files): | |||
| """ | |||
| Utility function for testing serdes files. It checks that a json file with the correct name is created | |||
| after serializing, and that the file stays the same after repeatedly saving and loading. | |||
| :param data_orig: original data pipeline to be serialized | |||
| :param filename: filename to be saved as json format | |||
| :param remove_json_files: whether to remove the json file after testing | |||
| :return: The data pipeline after serializing and deserializing using the original pipeline | |||
| """ | |||
| file1 = filename + ".json" | |||
| file2 = filename + "_1.json" | |||
| ds.serialize(data_orig, file1) | |||
| assert validate_jsonfile(file1) is True | |||
| assert validate_jsonfile("wrong_name.json") is False | |||
| data_changed = ds.deserialize(json_filepath=file1) | |||
| ds.serialize(data_changed, file2) | |||
| assert validate_jsonfile(file2) is True | |||
| assert filecmp.cmp(file1, file2) | |||
| # Remove the generated json file | |||
| if remove_json_files: | |||
| delete_json_files() | |||
| return data_changed | |||
| def validate_jsonfile(filepath): | |||
| try: | |||
| file_exist = os.path.exists(filepath) | |||
| @@ -276,7 +451,12 @@ def skip_test_minddataset(add_and_remove_cv_file): | |||
| if __name__ == '__main__': | |||
| test_imagefolder() | |||
| test_zip_dataset() | |||
| test_mnist_dataset() | |||
| test_random_crop() | |||
| test_serdes_imagefolder_dataset() | |||
| test_serdes_mnist_dataset() | |||
| test_serdes_cifar10_dataset() | |||
| test_serdes_celeba_dataset() | |||
| test_serdes_csv_dataset() | |||
| test_serdes_voc_dataset() | |||
| test_serdes_zip_dataset() | |||
| test_serdes_random_crop() | |||
| test_serdes_exception() | |||