| @@ -526,6 +526,11 @@ std::shared_ptr<TensorOp> CenterCropOperation::Build() { | |||||
| return tensor_op; | return tensor_op; | ||||
| } | } | ||||
| Status CenterCropOperation::to_json(nlohmann::json *out_json) { | |||||
| (*out_json)["size"] = size_; | |||||
| return Status::OK(); | |||||
| } | |||||
| // CropOperation. | // CropOperation. | ||||
| CropOperation::CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size) | CropOperation::CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size) | ||||
| : coordinates_(coordinates), size_(size) {} | : coordinates_(coordinates), size_(size) {} | ||||
| @@ -638,6 +643,11 @@ Status DecodeOperation::ValidateParams() { return Status::OK(); } | |||||
| std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); } | std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); } | ||||
| Status DecodeOperation::to_json(nlohmann::json *out_json) { | |||||
| (*out_json)["rgb"] = rgb_; | |||||
| return Status::OK(); | |||||
| } | |||||
| // EqualizeOperation | // EqualizeOperation | ||||
| Status EqualizeOperation::ValidateParams() { return Status::OK(); } | Status EqualizeOperation::ValidateParams() { return Status::OK(); } | ||||
| @@ -801,6 +811,14 @@ std::shared_ptr<TensorOp> NormalizeOperation::Build() { | |||||
| return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]); | return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]); | ||||
| } | } | ||||
| Status NormalizeOperation::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["mean"] = mean_; | |||||
| args["std"] = std_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| // NormalizePadOperation | // NormalizePadOperation | ||||
| NormalizePadOperation::NormalizePadOperation(const std::vector<float> &mean, const std::vector<float> &std, | NormalizePadOperation::NormalizePadOperation(const std::vector<float> &mean, const std::vector<float> &std, | ||||
| @@ -893,6 +911,15 @@ std::shared_ptr<TensorOp> PadOperation::Build() { | |||||
| return tensor_op; | return tensor_op; | ||||
| } | } | ||||
| Status PadOperation::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["padding"] = padding_; | |||||
| args["fill_value"] = fill_value_; | |||||
| args["padding_mode"] = padding_mode_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| // RandomAffineOperation | // RandomAffineOperation | ||||
| RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> °rees, | RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> °rees, | ||||
| const std::vector<float_t> &translate_range, | const std::vector<float_t> &translate_range, | ||||
| @@ -1188,6 +1215,16 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() { | |||||
| return tensor_op; | return tensor_op; | ||||
| } | } | ||||
| Status RandomColorAdjustOperation::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["brightness"] = brightness_; | |||||
| args["contrast"] = contrast_; | |||||
| args["saturation"] = saturation_; | |||||
| args["hue"] = hue_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| // RandomCropOperation | // RandomCropOperation | ||||
| RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed, | RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed, | ||||
| std::vector<uint8_t> fill_value, BorderType padding_mode) | std::vector<uint8_t> fill_value, BorderType padding_mode) | ||||
| @@ -1261,6 +1298,17 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() { | |||||
| return tensor_op; | return tensor_op; | ||||
| } | } | ||||
| Status RandomCropOperation::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["size"] = size_; | |||||
| args["padding"] = padding_; | |||||
| args["pad_if_needed"] = pad_if_needed_; | |||||
| args["fill_value"] = fill_value_; | |||||
| args["padding_mode"] = padding_mode_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| // RandomCropDecodeResizeOperation | // RandomCropDecodeResizeOperation | ||||
| RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale, | RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale, | ||||
| std::vector<float> ratio, | std::vector<float> ratio, | ||||
| @@ -1735,6 +1783,17 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() { | |||||
| return tensor_op; | return tensor_op; | ||||
| } | } | ||||
| Status RandomRotationOperation::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["degrees"] = degrees_; | |||||
| args["interpolation_mode"] = interpolation_mode_; | |||||
| args["expand"] = expand_; | |||||
| args["center"] = center_; | |||||
| args["fill_value"] = fill_value_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| // RandomSelectSubpolicyOperation. | // RandomSelectSubpolicyOperation. | ||||
| RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation( | RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation( | ||||
| std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy) | std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy) | ||||
| @@ -1889,6 +1948,14 @@ std::shared_ptr<TensorOp> RescaleOperation::Build() { | |||||
| return tensor_op; | return tensor_op; | ||||
| } | } | ||||
| Status RescaleOperation::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["rescale"] = rescale_; | |||||
| args["shift"] = shift_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| #endif | #endif | ||||
| // ResizeOperation | // ResizeOperation | ||||
| ResizeOperation::ResizeOperation(std::vector<int32_t> size, InterpolationMode interpolation) | ResizeOperation::ResizeOperation(std::vector<int32_t> size, InterpolationMode interpolation) | ||||
| @@ -1920,6 +1987,14 @@ std::shared_ptr<TensorOp> ResizeOperation::Build() { | |||||
| return std::make_shared<ResizeOp>(height, width, interpolation_); | return std::make_shared<ResizeOp>(height, width, interpolation_); | ||||
| } | } | ||||
| Status ResizeOperation::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["size"] = size_; | |||||
| args["interpolation"] = interpolation_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| // RotateOperation | // RotateOperation | ||||
| RotateOperation::RotateOperation() { rotate_op = std::make_shared<RotateOp>(0); } | RotateOperation::RotateOperation() { rotate_op = std::make_shared<RotateOp>(0); } | ||||
| @@ -44,18 +44,5 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DatasetCacheImpl::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["session_id"] = session_id_; | |||||
| args["cache_memory_size"] = cache_mem_sz_; | |||||
| args["spill"] = spill_; | |||||
| if (hostname_) args["hostname"] = hostname_.value(); | |||||
| if (port_) args["port"] = port_.value(); | |||||
| if (num_connections_) args["num_connections"] = num_connections_.value(); | |||||
| if (prefetch_sz_) args["prefetch_size"] = prefetch_sz_.value(); | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -60,8 +60,6 @@ class DatasetCacheImpl : public DatasetCache { | |||||
| ~DatasetCacheImpl() = default; | ~DatasetCacheImpl() = default; | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::shared_ptr<CacheClient> cache_client_; | std::shared_ptr<CacheClient> cache_client_; | ||||
| session_id_type session_id_; | session_id_type session_id_; | ||||
| @@ -36,5 +36,15 @@ Status PreBuiltDatasetCache::CreateCacheOp(int32_t num_workers, std::shared_ptr< | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status PreBuiltDatasetCache::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["session_id"] = cache_client_->session_id(); | |||||
| args["cache_memory_size"] = cache_client_->GetCacheMemSz(); | |||||
| args["spill"] = cache_client_->isSpill(); | |||||
| args["num_connections"] = cache_client_->GetNumConnections(); | |||||
| args["prefetch_size"] = cache_client_->GetPrefetchSize(); | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,6 +42,8 @@ class PreBuiltDatasetCache : public DatasetCache { | |||||
| Status ValidateParams() override { return Status::OK(); } | Status ValidateParams() override { return Status::OK(); } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::shared_ptr<CacheClient> cache_client_; | std::shared_ptr<CacheClient> cache_client_; | ||||
| }; | }; | ||||
| @@ -63,6 +63,15 @@ class BucketBatchByLengthNode : public DatasetNode { | |||||
| bool IsSizeDefined() override { return false; }; | bool IsSizeDefined() override { return false; }; | ||||
| /// \brief Getter functions | |||||
| const std::vector<std::string> &ColumnNames() const { return column_names_; } | |||||
| const std::vector<int32_t> &BucketBoundaries() const { return bucket_boundaries_; } | |||||
| const std::vector<int32_t> &BucketBatchSizes() const { return bucket_batch_sizes_; } | |||||
| const std::shared_ptr<TensorOp> &ElementLengthFunction() const { return element_length_function_; } | |||||
| const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &PadInfo() const { return pad_info_; } | |||||
| bool PadToBucketBoundary() const { return pad_to_bucket_boundary_; } | |||||
| bool DropRemainder() const { return drop_remainder_; } | |||||
| private: | private: | ||||
| std::vector<std::string> column_names_; | std::vector<std::string> column_names_; | ||||
| std::vector<int32_t> bucket_boundaries_; | std::vector<int32_t> bucket_boundaries_; | ||||
| @@ -72,6 +72,14 @@ class BuildSentenceVocabNode : public DatasetNode { | |||||
| /// \return Status of the node visit | /// \return Status of the node visit | ||||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | ||||
| /// \brief Getter functions | |||||
| const std::shared_ptr<SentencePieceVocab> &GetVocab() const { return vocab_; } | |||||
| const std::vector<std::string> &ColNames() const { return col_names_; } | |||||
| int32_t VocabSize() const { return vocab_size_; } | |||||
| float CharacterCoverage() const { return character_coverage_; } | |||||
| SentencePieceModel ModelType() const { return model_type_; } | |||||
| const std::unordered_map<std::string, std::string> &Params() const { return params_; } | |||||
| private: | private: | ||||
| std::shared_ptr<SentencePieceVocab> vocab_; | std::shared_ptr<SentencePieceVocab> vocab_; | ||||
| std::vector<std::string> col_names_; | std::vector<std::string> col_names_; | ||||
| @@ -70,6 +70,14 @@ class BuildVocabNode : public DatasetNode { | |||||
| /// \return Status of the node visit | /// \return Status of the node visit | ||||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | ||||
| /// \brief Getter functions | |||||
| const std::shared_ptr<Vocab> &GetVocab() const { return vocab_; } | |||||
| const std::vector<std::string> &Columns() const { return columns_; } | |||||
| const std::pair<int64_t, int64_t> &FreqRange() const { return freq_range_; } | |||||
| int64_t TopK() const { return top_k_; } | |||||
| const std::vector<std::string> &SpecialTokens() const { return special_tokens_; } | |||||
| bool SpecialFirst() const { return special_first_; } | |||||
| private: | private: | ||||
| std::shared_ptr<Vocab> vocab_; | std::shared_ptr<Vocab> vocab_; | ||||
| std::vector<std::string> columns_; | std::vector<std::string> columns_; | ||||
| @@ -61,6 +61,10 @@ class ConcatNode : public DatasetNode { | |||||
| bool IsSizeDefined() override { return false; } | bool IsSizeDefined() override { return false; } | ||||
| /// \brief Getter functions | |||||
| const std::vector<std::pair<int, int>> &ChildrenFlagAndNums() const { return children_flag_and_nums_; } | |||||
| const std::vector<std::pair<int, int>> &ChildrenStartEndIndex() const { return children_start_end_index_; } | |||||
| private: | private: | ||||
| std::shared_ptr<SamplerObj> sampler_; | std::shared_ptr<SamplerObj> sampler_; | ||||
| std::vector<std::pair<int, int>> children_flag_and_nums_; | std::vector<std::pair<int, int>> children_flag_and_nums_; | ||||
| @@ -73,5 +73,13 @@ Status FilterNode::AcceptAfter(IRNodePass *const p, bool *const modified) { | |||||
| return p->VisitAfter(shared_from_base<FilterNode>(), modified); | return p->VisitAfter(shared_from_base<FilterNode>(), modified); | ||||
| } | } | ||||
| Status FilterNode::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["input_columns"] = input_columns_; | |||||
| args["num_parallel_workers"] = num_workers_; | |||||
| args["predicate"] = "pyfunc"; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -70,6 +70,15 @@ class FilterNode : public DatasetNode { | |||||
| /// \return Status of the node visit | /// \return Status of the node visit | ||||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | ||||
| /// \brief Getter functions | |||||
| const std::shared_ptr<TensorOp> &Predicate() const { return predicate_; } | |||||
| const std::vector<std::string> &InputColumns() const { return input_columns_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::shared_ptr<TensorOp> predicate_; | std::shared_ptr<TensorOp> predicate_; | ||||
| std::vector<std::string> input_columns_; | std::vector<std::string> input_columns_; | ||||
| @@ -138,8 +138,8 @@ Status MapNode::to_json(nlohmann::json *out_json) { | |||||
| std::vector<nlohmann::json> ops; | std::vector<nlohmann::json> ops; | ||||
| std::vector<int32_t> cbs; | std::vector<int32_t> cbs; | ||||
| nlohmann::json op_args; | |||||
| for (auto op : operations_) { | for (auto op : operations_) { | ||||
| nlohmann::json op_args; | |||||
| RETURN_IF_NOT_OK(op->to_json(&op_args)); | RETURN_IF_NOT_OK(op->to_json(&op_args)); | ||||
| op_args["tensor_op_name"] = op->Name(); | op_args["tensor_op_name"] = op->Name(); | ||||
| ops.push_back(op_args); | ops.push_back(op_args); | ||||
| @@ -57,5 +57,11 @@ Status ProjectNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status ProjectNode::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["columns"] = columns_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -55,6 +55,14 @@ class ProjectNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Getter functions | |||||
| const std::vector<std::string> &Columns() const { return columns_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::vector<std::string> columns_; | std::vector<std::string> columns_; | ||||
| }; | }; | ||||
| @@ -29,7 +29,7 @@ namespace dataset { | |||||
| class RootNode : public DatasetNode { | class RootNode : public DatasetNode { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| RootNode() : DatasetNode() {} | |||||
| RootNode() : DatasetNode(), num_epochs_(0) {} | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit RootNode(std::shared_ptr<DatasetNode> child); | explicit RootNode(std::shared_ptr<DatasetNode> child); | ||||
| @@ -83,5 +83,12 @@ Status SkipNode::AcceptAfter(IRNodePass *const p, bool *const modified) { | |||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->VisitAfter(shared_from_base<SkipNode>(), modified); | return p->VisitAfter(shared_from_base<SkipNode>(), modified); | ||||
| } | } | ||||
| Status SkipNode::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["count"] = skip_count_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -80,6 +80,14 @@ class SkipNode : public DatasetNode { | |||||
| /// \return Status of the node visit | /// \return Status of the node visit | ||||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | ||||
| /// \brief Getter functions | |||||
| int32_t SkipCount() const { return skip_count_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| int32_t skip_count_; | int32_t skip_count_; | ||||
| }; | }; | ||||
| @@ -61,6 +61,12 @@ class AlbumNode : public MappableSourceNode { | |||||
| /// \return Status Status::OK() if get shard id successfully | /// \return Status Status::OK() if get shard id successfully | ||||
| Status GetShardId(int32_t *shard_id) override; | Status GetShardId(int32_t *shard_id) override; | ||||
| /// \brief Getter functions | |||||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||||
| const std::string &SchemaPath() const { return schema_path_; } | |||||
| const std::vector<std::string> &ColumnNames() const { return column_names_; } | |||||
| bool Decode() const { return decode_; } | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string schema_path_; | std::string schema_path_; | ||||
| @@ -144,5 +144,22 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CelebANode::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args, sampler_args; | |||||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||||
| args["sampler"] = sampler_args; | |||||
| args["num_parallel_workers"] = num_workers_; | |||||
| args["dataset_dir"] = dataset_dir_; | |||||
| args["decode"] = decode_; | |||||
| args["extensions"] = extensions_; | |||||
| args["usage"] = usage_; | |||||
| if (cache_ != nullptr) { | |||||
| nlohmann::json cache_args; | |||||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||||
| args["cache"] = cache_args; | |||||
| } | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -71,6 +71,17 @@ class CelebANode : public MappableSourceNode { | |||||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | ||||
| int64_t *dataset_size) override; | int64_t *dataset_size) override; | ||||
| /// \brief Getter functions | |||||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||||
| const std::string &Usage() const { return usage_; } | |||||
| bool Decode() const { return decode_; } | |||||
| const std::set<std::string> &Extensions() const { return extensions_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string usage_; | std::string usage_; | ||||
| @@ -95,5 +95,20 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Cifar100Node::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args, sampler_args; | |||||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||||
| args["sampler"] = sampler_args; | |||||
| args["num_parallel_workers"] = num_workers_; | |||||
| args["dataset_dir"] = dataset_dir_; | |||||
| args["usage"] = usage_; | |||||
| if (cache_ != nullptr) { | |||||
| nlohmann::json cache_args; | |||||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||||
| args["cache"] = cache_args; | |||||
| } | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -69,6 +69,15 @@ class Cifar100Node : public MappableSourceNode { | |||||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | ||||
| int64_t *dataset_size) override; | int64_t *dataset_size) override; | ||||
| /// \brief Getter functions | |||||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||||
| const std::string &Usage() const { return usage_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string usage_; | std::string usage_; | ||||
| @@ -93,5 +93,20 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Cifar10Node::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args, sampler_args; | |||||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||||
| args["sampler"] = sampler_args; | |||||
| args["num_parallel_workers"] = num_workers_; | |||||
| args["dataset_dir"] = dataset_dir_; | |||||
| args["usage"] = usage_; | |||||
| if (cache_ != nullptr) { | |||||
| nlohmann::json cache_args; | |||||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||||
| args["cache"] = cache_args; | |||||
| } | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -69,6 +69,15 @@ class Cifar10Node : public MappableSourceNode { | |||||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | ||||
| int64_t *dataset_size) override; | int64_t *dataset_size) override; | ||||
| /// \brief Getter functions | |||||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||||
| const std::string &Usage() const { return usage_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string usage_; | std::string usage_; | ||||
| @@ -252,5 +252,23 @@ Status CLUENode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CLUENode::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["num_parallel_workers"] = num_workers_; | |||||
| args["dataset_dir"] = dataset_files_; | |||||
| args["task"] = task_; | |||||
| args["usage"] = usage_; | |||||
| args["num_samples"] = num_samples_; | |||||
| args["shuffle"] = shuffle_; | |||||
| args["num_shards"] = num_shards_; | |||||
| args["shard_id"] = shard_id_; | |||||
| if (cache_ != nullptr) { | |||||
| nlohmann::json cache_args; | |||||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||||
| args["cache"] = cache_args; | |||||
| } | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -71,6 +71,20 @@ class CLUENode : public NonMappableSourceNode { | |||||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | ||||
| int64_t *dataset_size) override; | int64_t *dataset_size) override; | ||||
| /// \brief Getter functions | |||||
| const std::vector<std::string> &DatasetFiles() const { return dataset_files_; } | |||||
| const std::string &Task() const { return task_; } | |||||
| const std::string &Usage() const { return usage_; } | |||||
| int64_t NumSamples() const { return num_samples_; } | |||||
| ShuffleMode Shuffle() const { return shuffle_; } | |||||
| int32_t NumShards() const { return num_shards_; } | |||||
| int32_t ShardId() const { return shard_id_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| /// \brief Split string based on a character delimiter | /// \brief Split string based on a character delimiter | ||||
| /// \return A string vector | /// \return A string vector | ||||
| @@ -151,5 +151,22 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CocoNode::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args, sampler_args; | |||||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||||
| args["sampler"] = sampler_args; | |||||
| args["num_parallel_workers"] = num_workers_; | |||||
| args["dataset_dir"] = dataset_dir_; | |||||
| args["annotation_file"] = annotation_file_; | |||||
| args["task"] = task_; | |||||
| args["decode"] = decode_; | |||||
| if (cache_ != nullptr) { | |||||
| nlohmann::json cache_args; | |||||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||||
| args["cache"] = cache_args; | |||||
| } | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -69,6 +69,17 @@ class CocoNode : public MappableSourceNode { | |||||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | ||||
| int64_t *dataset_size) override; | int64_t *dataset_size) override; | ||||
| /// \brief Getter functions | |||||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||||
| const std::string &AnnotationFile() const { return annotation_file_; } | |||||
| const std::string &Task() const { return task_; } | |||||
| bool Decode() const { return decode_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string annotation_file_; | std::string annotation_file_; | ||||
| @@ -170,5 +170,23 @@ Status CSVNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CSVNode::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["num_parallel_workers"] = num_workers_; | |||||
| args["dataset_files"] = dataset_files_; | |||||
| args["field_delim"] = std::string(1, field_delim_); | |||||
| args["column_names"] = column_names_; | |||||
| args["num_samples"] = num_samples_; | |||||
| args["shuffle"] = shuffle_; | |||||
| args["num_shards"] = num_shards_; | |||||
| args["shard_id"] = shard_id_; | |||||
| if (cache_ != nullptr) { | |||||
| nlohmann::json cache_args; | |||||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||||
| args["cache"] = cache_args; | |||||
| } | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -92,6 +92,21 @@ class CSVNode : public NonMappableSourceNode { | |||||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | ||||
| int64_t *dataset_size) override; | int64_t *dataset_size) override; | ||||
| /// \brief Getter functions | |||||
| const std::vector<std::string> &DatasetFiles() const { return dataset_files_; } | |||||
| char FieldDelim() const { return field_delim_; } | |||||
| const std::vector<std::shared_ptr<CsvBase>> &ColumnDefaults() const { return column_defaults_; } | |||||
| const std::vector<std::string> &ColumnNames() const { return column_names_; } | |||||
| int64_t NumSamples() const { return num_samples_; } | |||||
| ShuffleMode Shuffle() const { return shuffle_; } | |||||
| int32_t NumShards() const { return num_shards_; } | |||||
| int32_t ShardId() const { return shard_id_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::vector<std::string> dataset_files_; | std::vector<std::string> dataset_files_; | ||||
| char field_delim_; | char field_delim_; | ||||
| @@ -83,6 +83,12 @@ class GeneratorNode : public MappableSourceNode { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| /// \brief Getter functions | |||||
| const py::function &GeneratorFunction() const { return generator_function_; } | |||||
| const std::vector<std::string> &ColumnNames() const { return column_names_; } | |||||
| const std::vector<DataType> &ColumnTypes() const { return column_types_; } | |||||
| const std::shared_ptr<SchemaObj> &Schema() const { return schema_; } | |||||
| private: | private: | ||||
| py::function generator_function_; | py::function generator_function_; | ||||
| std::vector<std::string> column_names_; | std::vector<std::string> column_names_; | ||||
| @@ -124,5 +124,23 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status ManifestNode::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args, sampler_args; | |||||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||||
| args["sampler"] = sampler_args; | |||||
| args["num_parallel_workers"] = num_workers_; | |||||
| args["dataset_file"] = dataset_file_; | |||||
| args["usage"] = usage_; | |||||
| args["class_indexing"] = class_index_; | |||||
| args["decode"] = decode_; | |||||
| if (cache_ != nullptr) { | |||||
| nlohmann::json cache_args; | |||||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||||
| args["cache"] = cache_args; | |||||
| } | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -70,6 +70,17 @@ class ManifestNode : public MappableSourceNode { | |||||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | ||||
| int64_t *dataset_size) override; | int64_t *dataset_size) override; | ||||
| /// \brief Getter functions | |||||
| const std::string &DatasetFile() const { return dataset_file_; } | |||||
| const std::string &Usage() const { return usage_; } | |||||
| bool Decode() const { return decode_; } | |||||
| const std::map<std::string, int32_t> &ClassIndex() const { return class_index_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::string dataset_file_; | std::string dataset_file_; | ||||
| std::string usage_; | std::string usage_; | ||||
| @@ -91,6 +91,14 @@ class RandomNode : public NonMappableSourceNode { | |||||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | ||||
| int64_t *dataset_size) override; | int64_t *dataset_size) override; | ||||
| /// \brief Getter functions | |||||
| int32_t TotalRows() const { return total_rows_; } | |||||
| const std::string &SchemaPath() const { return schema_path_; } | |||||
| const std::shared_ptr<SchemaObj> &GetSchema() const { return schema_; } | |||||
| const std::vector<std::string> &ColumnsList() const { return columns_list_; } | |||||
| const std::mt19937 &RandGen() const { return rand_gen_; } | |||||
| const std::unique_ptr<DataSchema> &GetDataSchema() const { return data_schema_; } | |||||
| private: | private: | ||||
| /// \brief A quick inline for producing a random number between (and including) min/max | /// \brief A quick inline for producing a random number between (and including) min/max | ||||
| /// \param[in] min minimum number that can be generated. | /// \param[in] min minimum number that can be generated. | ||||
| @@ -136,5 +136,21 @@ Status TextFileNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TextFileNode::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["num_parallel_workers"] = num_workers_; | |||||
| args["dataset_files"] = dataset_files_; | |||||
| args["num_samples"] = num_samples_; | |||||
| args["shuffle"] = shuffle_; | |||||
| args["num_shards"] = num_shards_; | |||||
| args["shard_id"] = shard_id_; | |||||
| if (cache_ != nullptr) { | |||||
| nlohmann::json cache_args; | |||||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||||
| args["cache"] = cache_args; | |||||
| } | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -71,6 +71,18 @@ class TextFileNode : public NonMappableSourceNode { | |||||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | ||||
| int64_t *dataset_size) override; | int64_t *dataset_size) override; | ||||
| /// \brief Getter functions | |||||
| const std::vector<std::string> &DatasetFiles() const { return dataset_files_; } | |||||
| int32_t NumSamples() const { return num_samples_; } | |||||
| int32_t NumShards() const { return num_shards_; } | |||||
| int32_t ShardId() const { return shard_id_; } | |||||
| ShuffleMode Shuffle() const { return shuffle_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::vector<std::string> dataset_files_; | std::vector<std::string> dataset_files_; | ||||
| int32_t num_samples_; | int32_t num_samples_; | ||||
| @@ -140,5 +140,23 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status VOCNode::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args, sampler_args; | |||||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||||
| args["sampler"] = sampler_args; | |||||
| args["num_parallel_workers"] = num_workers_; | |||||
| args["dataset_dir"] = dataset_dir_; | |||||
| args["task"] = task_; | |||||
| args["usage"] = usage_; | |||||
| args["class_indexing"] = class_index_; | |||||
| args["decode"] = decode_; | |||||
| if (cache_ != nullptr) { | |||||
| nlohmann::json cache_args; | |||||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||||
| args["cache"] = cache_args; | |||||
| } | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -71,6 +71,18 @@ class VOCNode : public MappableSourceNode { | |||||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | ||||
| int64_t *dataset_size) override; | int64_t *dataset_size) override; | ||||
| /// \brief Getter functions | |||||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||||
| const std::string &Task() const { return task_; } | |||||
| const std::string &Usage() const { return usage_; } | |||||
| const std::map<std::string, int32_t> &ClassIndex() const { return class_index_; } | |||||
| bool Decode() const { return decode_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| const std::string kColumnImage = "image"; | const std::string kColumnImage = "image"; | ||||
| const std::string kColumnTarget = "target"; | const std::string kColumnTarget = "target"; | ||||
| @@ -57,6 +57,10 @@ class SyncWaitNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Getter functions | |||||
| const std::string &ConditionName() const { return condition_name_; } | |||||
| const py::function &Callback() const { return callback_; } | |||||
| private: | private: | ||||
| std::string condition_name_; | std::string condition_name_; | ||||
| py::function callback_; | py::function callback_; | ||||
| @@ -81,5 +81,12 @@ Status TakeNode::AcceptAfter(IRNodePass *const p, bool *const modified) { | |||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->VisitAfter(shared_from_base<TakeNode>(), modified); | return p->VisitAfter(shared_from_base<TakeNode>(), modified); | ||||
| } | } | ||||
| Status TakeNode::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["count"] = take_count_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -80,6 +80,14 @@ class TakeNode : public DatasetNode { | |||||
| /// \return Status of the node visit | /// \return Status of the node visit | ||||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | ||||
| /// \brief Getter functions | |||||
| int32_t TakeCount() const { return take_count_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| int32_t take_count_; | int32_t take_count_; | ||||
| }; | }; | ||||
| @@ -116,5 +116,14 @@ Status TransferNode::AcceptAfter(IRNodePass *const p, bool *const modified) { | |||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->VisitAfter(shared_from_base<TransferNode>(), modified); | return p->VisitAfter(shared_from_base<TransferNode>(), modified); | ||||
| } | } | ||||
| Status TransferNode::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["send_epoch_end"] = send_epoch_end_; | |||||
| args["total_batch"] = total_batch_; | |||||
| args["create_data_info_queue"] = create_data_info_queue_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -70,6 +70,20 @@ class TransferNode : public DatasetNode { | |||||
| /// \return Status of the node visit | /// \return Status of the node visit | ||||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | ||||
| /// \brief Getter functions | |||||
| const std::string &QueueName() const { return queue_name_; } | |||||
| int32_t DeviceId() const { return device_id_; } | |||||
| const std::string &DeviceType() const { return device_type_; } | |||||
| int32_t PrefetchSize() const { return prefetch_size_; } | |||||
| bool SendEpochEnd() const { return send_epoch_end_; } | |||||
| int32_t TotalBatch() const { return total_batch_; } | |||||
| bool CreateDataInfoQueue() const { return create_data_info_queue_; } | |||||
| /// \brief Get the arguments of node | |||||
| /// \param[out] out_json JSON string of all attributes | |||||
| /// \return Status of the function | |||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::string queue_name_; | std::string queue_name_; | ||||
| int32_t device_id_; | int32_t device_id_; | ||||
| @@ -668,6 +668,8 @@ class PadOperation : public TensorOperation { | |||||
| std::string Name() const override { return kPadOperation; } | std::string Name() const override { return kPadOperation; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::vector<int32_t> padding_; | std::vector<int32_t> padding_; | ||||
| std::vector<uint8_t> fill_value_; | std::vector<uint8_t> fill_value_; | ||||
| @@ -729,6 +731,8 @@ class RandomColorAdjustOperation : public TensorOperation { | |||||
| std::string Name() const override { return kRandomColorAdjustOperation; } | std::string Name() const override { return kRandomColorAdjustOperation; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::vector<float> brightness_; | std::vector<float> brightness_; | ||||
| std::vector<float> contrast_; | std::vector<float> contrast_; | ||||
| @@ -750,6 +754,8 @@ class RandomCropOperation : public TensorOperation { | |||||
| std::string Name() const override { return kRandomCropOperation; } | std::string Name() const override { return kRandomCropOperation; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::vector<int32_t> size_; | std::vector<int32_t> size_; | ||||
| std::vector<int32_t> padding_; | std::vector<int32_t> padding_; | ||||
| @@ -936,6 +942,8 @@ class RandomRotationOperation : public TensorOperation { | |||||
| std::string Name() const override { return kRandomRotationOperation; } | std::string Name() const override { return kRandomRotationOperation; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::vector<float> degrees_; | std::vector<float> degrees_; | ||||
| InterpolationMode interpolation_mode_; | InterpolationMode interpolation_mode_; | ||||
| @@ -1037,6 +1045,8 @@ class RescaleOperation : public TensorOperation { | |||||
| std::string Name() const override { return kRescaleOperation; } | std::string Name() const override { return kRescaleOperation; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| float rescale_; | float rescale_; | ||||
| float shift_; | float shift_; | ||||
| @@ -105,6 +105,8 @@ class CenterCropOperation : public TensorOperation { | |||||
| std::string Name() const override { return kCenterCropOperation; } | std::string Name() const override { return kCenterCropOperation; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::vector<int32_t> size_; | std::vector<int32_t> size_; | ||||
| }; | }; | ||||
| @@ -137,6 +139,8 @@ class DecodeOperation : public TensorOperation { | |||||
| std::string Name() const override { return kDecodeOperation; } | std::string Name() const override { return kDecodeOperation; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| bool rgb_; | bool rgb_; | ||||
| }; | }; | ||||
| @@ -153,6 +157,8 @@ class NormalizeOperation : public TensorOperation { | |||||
| std::string Name() const override { return kNormalizeOperation; } | std::string Name() const override { return kNormalizeOperation; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::vector<float> mean_; | std::vector<float> mean_; | ||||
| std::vector<float> std_; | std::vector<float> std_; | ||||
| @@ -171,6 +177,8 @@ class ResizeOperation : public TensorOperation { | |||||
| std::string Name() const override { return kResizeOperation; } | std::string Name() const override { return kResizeOperation; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::vector<int32_t> size_; | std::vector<int32_t> size_; | ||||
| InterpolationMode interpolation_; | InterpolationMode interpolation_; | ||||
| @@ -33,5 +33,12 @@ Status TypeCastOp::OutputType(const std::vector<DataType> &inputs, std::vector<D | |||||
| outputs[0] = type_; | outputs[0] = type_; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TypeCastOp::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["data_type"] = type_.ToString(); | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -43,6 +43,8 @@ class TypeCastOp : public TensorOp { | |||||
| std::string Name() const override { return kTypeCastOp; } | std::string Name() const override { return kTypeCastOp; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| DataType type_; | DataType type_; | ||||
| }; | }; | ||||
| @@ -61,11 +61,5 @@ Status DecodeOp::OutputType(const std::vector<DataType> &inputs, std::vector<Dat | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DecodeOp::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["rgb"] = is_rgb_format_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,8 +42,6 @@ class DecodeOp : public TensorOp { | |||||
| std::string Name() const override { return kDecodeOp; } | std::string Name() const override { return kDecodeOp; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| bool is_rgb_format_ = true; | bool is_rgb_format_ = true; | ||||
| }; | }; | ||||
| @@ -138,16 +138,5 @@ Status RandomCropOp::OutputShape(const std::vector<TensorShape> &inputs, std::ve | |||||
| if (!outputs.empty()) return Status::OK(); | if (!outputs.empty()) return Status::OK(); | ||||
| return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | ||||
| } | } | ||||
| Status RandomCropOp::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["size"] = std::vector<int32_t>{crop_height_, crop_width_}; | |||||
| args["padding"] = std::vector<int32_t>{pad_top_, pad_bottom_, pad_left_, pad_right_}; | |||||
| args["pad_if_needed"] = pad_if_needed_; | |||||
| args["fill_value"] = std::tuple<uint8_t, uint8_t, uint8_t>{fill_r_, fill_g_, fill_b_}; | |||||
| args["padding_mode"] = border_type_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -79,8 +79,6 @@ class RandomCropOp : public TensorOp { | |||||
| std::string Name() const override { return kRandomCropOp; } | std::string Name() const override { return kRandomCropOp; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| protected: | protected: | ||||
| int32_t crop_height_ = 0; | int32_t crop_height_ = 0; | ||||
| int32_t crop_width_ = 0; | int32_t crop_width_ = 0; | ||||
| @@ -29,13 +29,5 @@ Status RescaleOp::OutputType(const std::vector<DataType> &inputs, std::vector<Da | |||||
| outputs[0] = DataType(DataType::DE_FLOAT32); | outputs[0] = DataType(DataType::DE_FLOAT32); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status RescaleOp::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["rescale"] = rescale_; | |||||
| args["shift"] = shift_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -41,8 +41,6 @@ class RescaleOp : public TensorOp { | |||||
| std::string Name() const override { return kRescaleOp; } | std::string Name() const override { return kRescaleOp; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| float rescale_; | float rescale_; | ||||
| float shift_; | float shift_; | ||||
| @@ -67,13 +67,5 @@ Status ResizeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector | |||||
| if (!outputs.empty()) return Status::OK(); | if (!outputs.empty()) return Status::OK(); | ||||
| return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | ||||
| } | } | ||||
| Status ResizeOp::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["size"] = std::vector<int32_t>{size1_, size2_}; | |||||
| args["interpolation"] = interpolation_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -61,8 +61,6 @@ class ResizeOp : public TensorOp { | |||||
| std::string Name() const override { return kResizeOp; } | std::string Name() const override { return kResizeOp; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| protected: | protected: | ||||
| int32_t size1_; | int32_t size1_; | ||||
| int32_t size2_; | int32_t size2_; | ||||
| @@ -167,19 +167,82 @@ def create_node(node): | |||||
| pyobj = None | pyobj = None | ||||
| # Find a matching Dataset class and call the constructor with the corresponding args. | # Find a matching Dataset class and call the constructor with the corresponding args. | ||||
| # When a new Dataset class is introduced, another if clause and parsing code needs to be added. | # When a new Dataset class is introduced, another if clause and parsing code needs to be added. | ||||
| if dataset_op == 'ImageFolderDataset': | |||||
| # Dataset Source Ops (in alphabetical order) | |||||
| if dataset_op == 'CelebADataset': | |||||
| sampler = construct_sampler(node.get('sampler')) | |||||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||||
| pyobj = pyclass(node['dataset_dir'], node.get('num_parallel_workers'), node.get('shuffle'), node.get('usage'), | |||||
| sampler, node.get('decode'), node.get('extensions'), num_samples, node.get('num_shards'), | |||||
| node.get('shard_id')) | |||||
| elif dataset_op == 'Cifar10Dataset': | |||||
| sampler = construct_sampler(node.get('sampler')) | |||||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||||
| pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples, node.get('num_parallel_workers'), | |||||
| node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) | |||||
| elif dataset_op == 'Cifar100Dataset': | |||||
| sampler = construct_sampler(node.get('sampler')) | |||||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||||
| pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples, node.get('num_parallel_workers'), | |||||
| node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) | |||||
| elif dataset_op == 'ClueDataset': | |||||
| shuffle = to_shuffle_mode(node.get('shuffle')) | |||||
| if shuffle is not None and isinstance(shuffle, str): | |||||
| shuffle = de.Shuffle(shuffle) | |||||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||||
| pyobj = pyclass(node['dataset_files'], node.get('task'), | |||||
| node.get('usage'), num_samples, node.get('num_parallel_workers'), shuffle, | |||||
| node.get('num_shards'), node.get('shard_id')) | |||||
| elif dataset_op == 'CocoDataset': | |||||
| sampler = construct_sampler(node.get('sampler')) | |||||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||||
| pyobj = pyclass(node['dataset_dir'], node.get('annotation_file'), node.get('task'), num_samples, | |||||
| node.get('num_parallel_workers'), node.get('shuffle'), node.get('decode'), sampler, | |||||
| node.get('num_shards'), node.get('shard_id')) | |||||
| elif dataset_op == 'CSVDataset': | |||||
| shuffle = to_shuffle_mode(node.get('shuffle')) | |||||
| if shuffle is not None and isinstance(shuffle, str): | |||||
| shuffle = de.Shuffle(shuffle) | |||||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||||
| pyobj = pyclass(node['dataset_files'], node.get('field_delim'), | |||||
| node.get('column_defaults'), node.get('column_names'), num_samples, | |||||
| node.get('num_parallel_workers'), shuffle, | |||||
| node.get('num_shards'), node.get('shard_id')) | |||||
| elif dataset_op == 'ImageFolderDataset': | |||||
| sampler = construct_sampler(node.get('sampler')) | sampler = construct_sampler(node.get('sampler')) | ||||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | ||||
| pyobj = pyclass(node['dataset_dir'], num_samples, node.get('num_parallel_workers'), | pyobj = pyclass(node['dataset_dir'], num_samples, node.get('num_parallel_workers'), | ||||
| node.get('shuffle'), sampler, node.get('extensions'), | node.get('shuffle'), sampler, node.get('extensions'), | ||||
| node.get('class_indexing'), node.get('decode'), node.get('num_shards'), | node.get('class_indexing'), node.get('decode'), node.get('num_shards'), | ||||
| node.get('shard_id'), node.get('cache')) | |||||
| node.get('shard_id')) | |||||
| elif dataset_op == 'ManifestDataset': | |||||
| sampler = construct_sampler(node.get('sampler')) | |||||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||||
| pyobj = pyclass(node['dataset_file'], node['usage'], num_samples, | |||||
| node.get('num_parallel_workers'), node.get('shuffle'), sampler, | |||||
| node.get('class_indexing'), node.get('decode'), node.get('num_shards'), | |||||
| node.get('shard_id')) | |||||
| elif dataset_op == 'MnistDataset': | elif dataset_op == 'MnistDataset': | ||||
| sampler = construct_sampler(node.get('sampler')) | sampler = construct_sampler(node.get('sampler')) | ||||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | ||||
| pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples, node.get('num_parallel_workers'), | pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples, node.get('num_parallel_workers'), | ||||
| node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'), node.get('cache')) | |||||
| node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) | |||||
| elif dataset_op == 'TextFileDataset': | |||||
| shuffle = to_shuffle_mode(node.get('shuffle')) | |||||
| if shuffle is not None and isinstance(shuffle, str): | |||||
| shuffle = de.Shuffle(shuffle) | |||||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||||
| pyobj = pyclass(node['dataset_files'], num_samples, | |||||
| node.get('num_parallel_workers'), shuffle, | |||||
| node.get('num_shards'), node.get('shard_id')) | |||||
| elif dataset_op == 'TFRecordDataset': | elif dataset_op == 'TFRecordDataset': | ||||
| shuffle = to_shuffle_mode(node.get('shuffle')) | shuffle = to_shuffle_mode(node.get('shuffle')) | ||||
| @@ -188,30 +251,50 @@ def create_node(node): | |||||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | ||||
| pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('columns_list'), | pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('columns_list'), | ||||
| num_samples, node.get('num_parallel_workers'), | num_samples, node.get('num_parallel_workers'), | ||||
| shuffle, node.get('num_shards'), node.get('shard_id'), node.get('cache')) | |||||
| shuffle, node.get('num_shards'), node.get('shard_id')) | |||||
| elif dataset_op == 'Repeat': | |||||
| pyobj = de.Dataset().repeat(node.get('count')) | |||||
| elif dataset_op == 'VOCDataset': | |||||
| sampler = construct_sampler(node.get('sampler')) | |||||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||||
| pyobj = pyclass(node['dataset_dir'], node.get('task'), node.get('usage'), node.get('class_indexing'), | |||||
| num_samples, node.get('num_parallel_workers'), node.get('shuffle'), | |||||
| node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id')) | |||||
| # Dataset Ops (in alphabetical order) | |||||
| elif dataset_op == 'Batch': | |||||
| pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder')) | |||||
| elif dataset_op == 'Map': | elif dataset_op == 'Map': | ||||
| tensor_ops = construct_tensor_ops(node.get('operations')) | tensor_ops = construct_tensor_ops(node.get('operations')) | ||||
| pyobj = de.Dataset().map(tensor_ops, node.get('input_columns'), node.get('output_columns'), | pyobj = de.Dataset().map(tensor_ops, node.get('input_columns'), node.get('output_columns'), | ||||
| node.get('column_order'), node.get('num_parallel_workers'), | node.get('column_order'), node.get('num_parallel_workers'), | ||||
| True, node.get('cache'), node.get('callbacks')) | |||||
| True, node.get('callbacks')) | |||||
| elif dataset_op == 'Project': | |||||
| pyobj = de.Dataset().project(node['columns']) | |||||
| elif dataset_op == 'Rename': | |||||
| pyobj = de.Dataset().rename(node['input_columns'], node['output_columns']) | |||||
| elif dataset_op == 'Repeat': | |||||
| pyobj = de.Dataset().repeat(node.get('count')) | |||||
| elif dataset_op == 'Shuffle': | elif dataset_op == 'Shuffle': | ||||
| pyobj = de.Dataset().shuffle(node.get('buffer_size')) | pyobj = de.Dataset().shuffle(node.get('buffer_size')) | ||||
| elif dataset_op == 'Batch': | |||||
| pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder')) | |||||
| elif dataset_op == 'Skip': | |||||
| pyobj = de.Dataset().skip(node.get('count')) | |||||
| elif dataset_op == 'Take': | |||||
| pyobj = de.Dataset().take(node.get('count')) | |||||
| elif dataset_op == 'Transfer': | |||||
| pyobj = de.Dataset().to_device(node.get('send_epoch_end'), node.get('create_data_info_queue')) | |||||
| elif dataset_op == 'Zip': | elif dataset_op == 'Zip': | ||||
| # Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller. | # Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller. | ||||
| pyobj = de.ZipDataset((de.Dataset(), de.Dataset())) | pyobj = de.ZipDataset((de.Dataset(), de.Dataset())) | ||||
| elif dataset_op == 'Rename': | |||||
| pyobj = de.Dataset().rename(node['input_columns'], node['output_columns']) | |||||
| else: | else: | ||||
| raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize().") | raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize().") | ||||
| @@ -252,35 +335,59 @@ def construct_tensor_ops(operations): | |||||
| """Instantiate tensor op object(s) based on the information from dictionary['operations']""" | """Instantiate tensor op object(s) based on the information from dictionary['operations']""" | ||||
| result = [] | result = [] | ||||
| for op in operations: | for op in operations: | ||||
| op_name = op['tensor_op_name'][:-2] # to remove op from the back of the name | |||||
| op_name = op['tensor_op_name'] | |||||
| op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"] | op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"] | ||||
| op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"] | op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"] | ||||
| if op_name == "HwcToChw": op_name = "HWC2CHW" | |||||
| if hasattr(op_module_vis, op_name): | if hasattr(op_module_vis, op_name): | ||||
| op_class = getattr(op_module_vis, op_name) | op_class = getattr(op_module_vis, op_name) | ||||
| elif hasattr(op_module_trans, op_name): | |||||
| elif hasattr(op_module_trans, op_name[:-2]): | |||||
| op_name = op_name[:-2] # to remove op from the back of the name | |||||
| op_class = getattr(op_module_trans, op_name) | op_class = getattr(op_module_trans, op_name) | ||||
| else: | else: | ||||
| raise RuntimeError(op_name + " is not yet supported by deserialize().") | raise RuntimeError(op_name + " is not yet supported by deserialize().") | ||||
| if op_name == 'Decode': | |||||
| # Transforms Ops (in alphabetical order) | |||||
| if op_name == 'OneHot': | |||||
| result.append(op_class(op['num_classes'])) | |||||
| elif op_name == 'TypeCast': | |||||
| result.append(op_class(to_mstype(op['data_type']))) | |||||
| # Vision Ops (in alphabetical order) | |||||
| elif op_name == 'CenterCrop': | |||||
| result.append(op_class(op['size'])) | |||||
| elif op_name == 'Decode': | |||||
| result.append(op_class(op.get('rgb'))) | result.append(op_class(op.get('rgb'))) | ||||
| elif op_name == 'HWC2CHW': | |||||
| result.append(op_class()) | |||||
| elif op_name == 'Normalize': | |||||
| result.append(op_class(op['mean'], op['std'])) | |||||
| elif op_name == 'Pad': | |||||
| result.append(op_class(op['padding'], tuple(op['fill_value']), Border(to_border_mode(op['padding_mode'])))) | |||||
| elif op_name == 'RandomColorAdjust': | |||||
| result.append(op_class(op.get('brightness'), op.get('contrast'), op.get('saturation'), | |||||
| op.get('hue'))) | |||||
| elif op_name == 'RandomCrop': | elif op_name == 'RandomCrop': | ||||
| result.append(op_class(op['size'], op.get('padding'), op.get('pad_if_needed'), | result.append(op_class(op['size'], op.get('padding'), op.get('pad_if_needed'), | ||||
| tuple(op.get('fill_value')), Border(to_border_mode(op.get('padding_mode'))))) | tuple(op.get('fill_value')), Border(to_border_mode(op.get('padding_mode'))))) | ||||
| elif op_name == 'Resize': | |||||
| result.append(op_class(op['size'], Inter(to_interpolation_mode(op.get('interpolation'))))) | |||||
| elif op_name == 'RandomRotation': | |||||
| result.append(op_class(op['degrees'], to_interpolation_mode(op.get('interpolation_mode')), op.get('expand'), | |||||
| tuple(op.get('center')), tuple(op.get('fill_value')))) | |||||
| elif op_name == 'Rescale': | elif op_name == 'Rescale': | ||||
| result.append(op_class(op['rescale'], op['shift'])) | result.append(op_class(op['rescale'], op['shift'])) | ||||
| elif op_name == 'HWC2CHW': | |||||
| result.append(op_class()) | |||||
| elif op_name == 'OneHot': | |||||
| result.append(op_class(op['num_classes'])) | |||||
| elif op_name == 'Resize': | |||||
| result.append(op_class(op['size'], to_interpolation_mode(op.get('interpolation')))) | |||||
| else: | else: | ||||
| raise ValueError("Tensor op name is unknown: {}.".format(op_name)) | raise ValueError("Tensor op name is unknown: {}.".format(op_name)) | ||||
| @@ -19,11 +19,13 @@ import filecmp | |||||
| import glob | import glob | ||||
| import json | import json | ||||
| import os | import os | ||||
| import pytest | |||||
| import numpy as np | import numpy as np | ||||
| from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME | from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME | ||||
| from util import config_get_set_num_parallel_workers, config_get_set_seed | from util import config_get_set_num_parallel_workers, config_get_set_seed | ||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.transforms.c_transforms as c | import mindspore.dataset.transforms.c_transforms as c | ||||
| import mindspore.dataset.vision.c_transforms as vision | import mindspore.dataset.vision.c_transforms as vision | ||||
| @@ -31,7 +33,7 @@ from mindspore import log as logger | |||||
| from mindspore.dataset.vision import Inter | from mindspore.dataset.vision import Inter | ||||
| def skip_test_imagefolder(remove_json_files=True): | |||||
| def test_serdes_imagefolder_dataset(remove_json_files=True): | |||||
| """ | """ | ||||
| Test simulating resnet50 dataset pipeline. | Test simulating resnet50 dataset pipeline. | ||||
| """ | """ | ||||
| @@ -100,7 +102,7 @@ def skip_test_imagefolder(remove_json_files=True): | |||||
| delete_json_files() | delete_json_files() | ||||
| def test_mnist_dataset(remove_json_files=True): | |||||
| def test_serdes_mnist_dataset(remove_json_files=True): | |||||
| """ | """ | ||||
| Test serdes on mnist dataset pipeline. | Test serdes on mnist dataset pipeline. | ||||
| """ | """ | ||||
| @@ -141,7 +143,7 @@ def test_mnist_dataset(remove_json_files=True): | |||||
| delete_json_files() | delete_json_files() | ||||
| def test_zip_dataset(remove_json_files=True): | |||||
| def test_serdes_zip_dataset(remove_json_files=True): | |||||
| """ | """ | ||||
| Test serdes on zip dataset pipeline. | Test serdes on zip dataset pipeline. | ||||
| """ | """ | ||||
| @@ -185,7 +187,7 @@ def test_zip_dataset(remove_json_files=True): | |||||
| delete_json_files() | delete_json_files() | ||||
| def skip_test_random_crop(): | |||||
| def test_serdes_random_crop(): | |||||
| """ | """ | ||||
| Test serdes on RandomCrop pipeline. | Test serdes on RandomCrop pipeline. | ||||
| """ | """ | ||||
| @@ -225,6 +227,179 @@ def skip_test_random_crop(): | |||||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | ds.config.set_num_parallel_workers(original_num_parallel_workers) | ||||
| def test_serdes_cifar10_dataset(remove_json_files=True): | |||||
| """ | |||||
| Test serdes on Cifar10 dataset pipeline | |||||
| """ | |||||
| data_dir = "../data/dataset/testCifar10Data" | |||||
| original_seed = config_get_set_seed(1) | |||||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||||
| data1 = ds.Cifar10Dataset(data_dir, num_samples=10, shuffle=False) | |||||
| data1 = data1.take(6) | |||||
| trans = [ | |||||
| vision.RandomCrop((32, 32), (4, 4, 4, 4)), | |||||
| vision.Resize((224, 224)), | |||||
| vision.Rescale(1.0 / 255.0, 0.0), | |||||
| vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), | |||||
| vision.HWC2CHW() | |||||
| ] | |||||
| type_cast_op = c.TypeCast(mstype.int32) | |||||
| data1 = data1.map(operations=type_cast_op, input_columns="label") | |||||
| data1 = data1.map(operations=trans, input_columns="image") | |||||
| data1 = data1.batch(3, drop_remainder=True) | |||||
| data1 = data1.repeat(1) | |||||
| data2 = util_check_serialize_deserialize_file(data1, "cifar10_dataset_pipeline", remove_json_files) | |||||
| num_samples = 0 | |||||
| # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2) | |||||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), | |||||
| data2.create_dict_iterator(num_epochs=1, output_numpy=True)): | |||||
| np.testing.assert_array_equal(item1['image'], item2['image']) | |||||
| num_samples += 1 | |||||
| assert num_samples == 2 | |||||
| # Restore configuration seed and num_parallel_workers | |||||
| ds.config.set_seed(original_seed) | |||||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||||
| def test_serdes_celeba_dataset(remove_json_files=True): | |||||
| """ | |||||
| Test serdes on CelebA dataset pipeline. | |||||
| """ | |||||
| DATA_DIR = "../data/dataset/testCelebAData/" | |||||
| data1 = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0) | |||||
| # define map operations | |||||
| data1 = data1.repeat(2) | |||||
| center_crop = vision.CenterCrop((80, 80)) | |||||
| pad_op = vision.Pad(20, fill_value=(20, 20, 20)) | |||||
| data1 = data1.map(operations=[center_crop, pad_op], input_columns=["image"], num_parallel_workers=8) | |||||
| data2 = util_check_serialize_deserialize_file(data1, "celeba_dataset_pipeline", remove_json_files) | |||||
| num_samples = 0 | |||||
| # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2) | |||||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), | |||||
| data2.create_dict_iterator(num_epochs=1, output_numpy=True)): | |||||
| np.testing.assert_array_equal(item1['image'], item2['image']) | |||||
| num_samples += 1 | |||||
| assert num_samples == 8 | |||||
| def test_serdes_csv_dataset(remove_json_files=True): | |||||
| """ | |||||
| Test serdes on CSVDataset pipeline. | |||||
| """ | |||||
| DATA_DIR = "../data/dataset/testCSV/1.csv" | |||||
| data1 = ds.CSVDataset( | |||||
| DATA_DIR, | |||||
| column_defaults=["1", "2", "3", "4"], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False) | |||||
| columns = ["col1", "col4", "col2"] | |||||
| data1 = data1.project(columns=columns) | |||||
| data2 = util_check_serialize_deserialize_file(data1, "csv_dataset_pipeline", remove_json_files) | |||||
| num_samples = 0 | |||||
| # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2) | |||||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), | |||||
| data2.create_dict_iterator(num_epochs=1, output_numpy=True)): | |||||
| np.testing.assert_array_equal(item1['col1'], item2['col1']) | |||||
| np.testing.assert_array_equal(item1['col2'], item2['col2']) | |||||
| np.testing.assert_array_equal(item1['col4'], item2['col4']) | |||||
| num_samples += 1 | |||||
| assert num_samples == 3 | |||||
| def test_serdes_voc_dataset(remove_json_files=True): | |||||
| """ | |||||
| Test serdes on VOC dataset pipeline. | |||||
| """ | |||||
| data_dir = "../data/dataset/testVOC2012" | |||||
| original_seed = config_get_set_seed(1) | |||||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||||
| # define map operations | |||||
| random_color_adjust_op = vision.RandomColorAdjust(brightness=(0.5, 0.5)) | |||||
| random_rotation_op = vision.RandomRotation((0, 90), expand=True, resample=Inter.BILINEAR, center=(50, 50), | |||||
| fill_value=150) | |||||
| data1 = ds.VOCDataset(data_dir, task="Detection", usage="train", decode=True) | |||||
| data1 = data1.map(operations=random_color_adjust_op, input_columns=["image"]) | |||||
| data1 = data1.map(operations=random_rotation_op, input_columns=["image"]) | |||||
| data1 = data1.skip(2) | |||||
| data2 = util_check_serialize_deserialize_file(data1, "voc_dataset_pipeline", remove_json_files) | |||||
| num_samples = 0 | |||||
| # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2) | |||||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), | |||||
| data2.create_dict_iterator(num_epochs=1, output_numpy=True)): | |||||
| np.testing.assert_array_equal(item1['image'], item2['image']) | |||||
| num_samples += 1 | |||||
| assert num_samples == 7 | |||||
| # Restore configuration seed and num_parallel_workers | |||||
| ds.config.set_seed(original_seed) | |||||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||||
| def test_serdes_to_device(remove_json_files=True): | |||||
| """ | |||||
| Test serdes on a TFRecord dataset pipeline that ends with to_device (Transfer). | |||||
| """ | |||||
| data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||||
| schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||||
| data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False) | |||||
| data1 = data1.to_device() | |||||
| util_check_serialize_deserialize_file(data1, "transfer_dataset_pipeline", remove_json_files) | |||||
| def test_serdes_exception(): | |||||
| """ | |||||
| Test exception case in serdes | |||||
| """ | |||||
| data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||||
| schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||||
| data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False) | |||||
| data1 = data1.filter(input_columns=["image", "label"], predicate=lambda data: data < 11, num_parallel_workers=4) | |||||
| data1_json = ds.serialize(data1) | |||||
| with pytest.raises(RuntimeError) as msg: | |||||
| ds.deserialize(input_dict=data1_json) | |||||
| assert "Filter is not yet supported by ds.engine.deserialize" in str(msg) | |||||
| def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files): | |||||
| """ | |||||
| Utility function for testing serdes files. It checks that a json file is indeed created with the correct name | |||||
| after serializing and that it remains the same after repeatedly saving and loading. | |||||
| :param data_orig: original data pipeline to be serialized | |||||
| :param filename: filename to be saved as json format | |||||
| :param remove_json_files: whether to remove the json file after testing | |||||
| :return: The data pipeline after serializing and deserializing using the original pipeline | |||||
| """ | |||||
| file1 = filename + ".json" | |||||
| file2 = filename + "_1.json" | |||||
| ds.serialize(data_orig, file1) | |||||
| assert validate_jsonfile(file1) is True | |||||
| assert validate_jsonfile("wrong_name.json") is False | |||||
| data_changed = ds.deserialize(json_filepath=file1) | |||||
| ds.serialize(data_changed, file2) | |||||
| assert validate_jsonfile(file2) is True | |||||
| assert filecmp.cmp(file1, file2) | |||||
| # Remove the generated json file | |||||
| if remove_json_files: | |||||
| delete_json_files() | |||||
| return data_changed | |||||
| def validate_jsonfile(filepath): | def validate_jsonfile(filepath): | ||||
| try: | try: | ||||
| file_exist = os.path.exists(filepath) | file_exist = os.path.exists(filepath) | ||||
| @@ -276,7 +451,12 @@ def skip_test_minddataset(add_and_remove_cv_file): | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_imagefolder() | |||||
| test_zip_dataset() | |||||
| test_mnist_dataset() | |||||
| test_random_crop() | |||||
| test_serdes_imagefolder_dataset() | |||||
| test_serdes_mnist_dataset() | |||||
| test_serdes_cifar10_dataset() | |||||
| test_serdes_celeba_dataset() | |||||
| test_serdes_csv_dataset() | |||||
| test_serdes_voc_dataset() | |||||
| test_serdes_zip_dataset() | |||||
| test_serdes_random_crop() | |||||
| test_serdes_exception() | |||||