diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc
index 8d0e4046f6..c7fb9955a9 100644
--- a/mindspore/ccsrc/dataset/api/de_pipeline.cc
+++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc
@@ -71,7 +71,8 @@ static std::unordered_map g_parse_op_func_ = {{kStorage, &D
   {kCifar100, &DEPipeline::ParseCifar100Op},
   {kCelebA, &DEPipeline::ParseCelebAOp},
   {kRandomData, &DEPipeline::ParseRandomDataOp},
-  {kTextFile, &DEPipeline::ParseTextFileOp}};
+  {kTextFile, &DEPipeline::ParseTextFileOp},
+  {kBuildVocab, &DEPipeline::ParseBuildVocabOp}};
 
 DEPipeline::DEPipeline() : iterator_(nullptr) {
   try {
@@ -1235,5 +1236,36 @@ Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
   }
   return Status::OK();
 }
+Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
+  std::shared_ptr<BuildVocabOp::Builder> builder = std::make_shared<BuildVocabOp::Builder>();
+  for (auto arg : args) {
+    std::string key = py::str(arg.first);
+    py::handle value = arg.second;
+    if (!value.is_none()) {
+      if (key == "freq_range") {
+        py::tuple tp = py::reinterpret_borrow<py::tuple>(value);
+        if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow<py::int_>(tp[0]));
+        if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow<py::int_>(tp[1]));
+      }
+      if (key == "top_k") {
+        builder->SetTopK(py::reinterpret_borrow<py::int_>(value));
+      }
+      if (key == "columns") {
+        (void)builder->SetColumnNames(ToStringVector(value));
+      }
+      if (key == "vocab") {
+        (void)builder->SetVocab(value.cast<std::shared_ptr<Vocab>>());
+      }
+      if (key == "num_parallel_workers") {
+        (void)builder->SetNumWorkers(ToInt(value));
+      }
+    }
+  }
+  std::shared_ptr<BuildVocabOp> op;
+  RETURN_IF_NOT_OK(builder->Build(&op));
+  *ptr = op;
+  return Status::OK();
+}
+
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h
index 3b5001fa36..40133fc7b7 100644
--- a/mindspore/ccsrc/dataset/api/de_pipeline.h
+++ b/mindspore/ccsrc/dataset/api/de_pipeline.h
@@ -63,7 +63,8 @@ enum OpName {
   kCifar100,
   kCelebA,
   kRandomData,
-  kTextFile
+  kTextFile,
+  kBuildVocab
 };
 
 // The C++ binder class that we expose to the python script.
@@ -163,6 +164,8 @@ class DEPipeline {
 
   Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
+  Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+
  private:
   // Execution tree that links the dataset operators.
std::shared_ptr tree_; diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 57fbaea027..00c2ff594b 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -514,6 +514,7 @@ void bindInfoObjects(py::module *m) { void bindVocabObjects(py::module *m) { (void)py::class_>(*m, "Vocab") + .def(py::init<>()) .def_static("from_list", [](const py::list &words) { std::shared_ptr v; @@ -624,6 +625,7 @@ PYBIND11_MODULE(_c_dataengine, m) { .value("CIFAR10", OpName::kCifar10) .value("CIFAR100", OpName::kCifar100) .value("RANDOMDATA", OpName::kRandomData) + .value("BUILDVOCAB", OpName::kBuildVocab) .value("CELEBA", OpName::kCelebA) .value("TEXTFILE", OpName::kTextFile); diff --git a/mindspore/ccsrc/dataset/core/client.h b/mindspore/ccsrc/dataset/core/client.h index aa5e85f7de..d4458a70b8 100644 --- a/mindspore/ccsrc/dataset/core/client.h +++ b/mindspore/ccsrc/dataset/core/client.h @@ -27,6 +27,7 @@ #include "dataset/engine/dataset_iterator.h" #include "dataset/engine/datasetops/barrier_op.h" #include "dataset/engine/datasetops/batch_op.h" +#include "dataset/engine/datasetops/build_vocab_op.h" #include "dataset/engine/datasetops/dataset_op.h" #include "dataset/engine/datasetops/device_queue_op.h" #include "dataset/engine/datasetops/map_op.h" diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt index 70065df5f4..63265a2225 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt @@ -19,5 +19,6 @@ add_library(engine-datasetops OBJECT zip_op.cc concat_op.cc filter_op.cc + build_vocab_op.cc ) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc new file mode 100644 index 0000000000..2825950f25 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc @@ -0,0 +1,179 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "dataset/engine/datasetops/build_vocab_op.h"
+
+#include <algorithm>
+#include <limits>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include "dataset/core/config_manager.h"
+
+namespace mindspore {
+namespace dataset {
+
+BuildVocabOp::BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names,
+                           std::pair<int64_t, int64_t> freq_r, int64_t top_k, int32_t num_workers,
+                           int32_t op_conn_size)
+    : ParallelOp(num_workers, op_conn_size),
+      interval_(op_conn_size * num_workers),
+      vocab_(vocab),
+      col_names_(col_names),
+      freq_range_(freq_r),
+      top_k_(top_k) {
+  // init two queues for thread sync
+  distributor_queue_ = std::make_unique<Queue<TensorRow>>(num_workers * op_conn_size);
+  collector_queue_ =
+    std::make_unique<Queue<std::unique_ptr<std::unordered_map<std::string, int64_t>>>>(num_workers * op_conn_size);
+}
+
+Status BuildVocabOp::WorkerEntry(int32_t worker_id) {
+  TaskManager::FindMe()->Post();
+  TensorRow new_row;
+  RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row));
+  std::unique_ptr<std::unordered_map<std::string, int64_t>> wrkr_map =
+    std::make_unique<std::unordered_map<std::string, int64_t>>();
+  int32_t row_cnt = 0;
+  while (!new_row.empty()) {
+    for (int32_t col : col_ids_) {
+      CHECK_FAIL_RETURN_UNEXPECTED(!new_row[col]->type().IsNumeric(), "from_dataset only works on string columns");
+      for (auto itr = new_row[col]->begin<std::string_view>(); itr != new_row[col]->end<std::string_view>(); itr++) {
+        (*wrkr_map)[std::string(*itr)] += 1;
+      }
+    }
+    row_cnt++;  // row is processed by this point
+    if ((row_cnt % interval_ == 0) && ((row_cnt / interval_) % num_workers_ == worker_id) && (!wrkr_map->empty())) {
+      RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map)));
+      wrkr_map = std::make_unique<std::unordered_map<std::string, int64_t>>();
+    }
+    RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row));
+  }
+  // clean up
+  if (!wrkr_map->empty()) {
+    RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map)));
+  }
+  // empty map as quit signal
+  RETURN_IF_NOT_OK(collector_queue_->Add(std::make_unique<std::unordered_map<std::string, int64_t>>()));
+  return Status::OK();
+}
+
+Status BuildVocabOp::operator()() {
+  // launch the collector thread
+  RETURN_UNEXPECTED_IF_NULL(tree_);
+  RETURN_IF_NOT_OK(distributor_queue_->Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(collector_queue_->Register(tree_->AllTasks()));
+  // launch worker threads and collector thread
+  RETURN_IF_NOT_OK(
+    tree_->LaunchWorkers(num_workers_, std::bind(&BuildVocabOp::WorkerEntry, this, std::placeholders::_1)));
+  RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("collector", std::bind(&BuildVocabOp::CollectorThread, this)));
+  TaskManager::FindMe()->Post();
+  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
+  TensorRow new_row;
+  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
+  RETURN_IF_NOT_OK(AssignColMapFromChild());
+  if (!col_names_.empty()) {
+    col_ids_.reserve(col_names_.size());
+    for (std::string col : col_names_) {
+      auto itr = column_name_id_map_.find(col);
+      CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(), col + " column doesn't exist");
+      col_ids_.push_back(itr->second);
+    }
+  } else {
+    col_ids_.reserve(column_name_id_map_.size());
+    for (const auto &p : column_name_id_map_) {
+      col_ids_.push_back(p.second);
+    }
+  }
+  bool eoe_warning = false;  // give out warning if receive more than 1 eoe
+  while (child_iterator_->eof_handled() == false) {
+    while (new_row.empty() == false) {
+      RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(new_row));
+      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
+    }
+    // advance past the EOE so that an EOF (or a repeated epoch) can be detected
+    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
+    CHECK_FAIL_RETURN_UNEXPECTED(!eoe_warning, "no op should be after from_dataset (repeat detected)");
+    eoe_warning = true;
+  }
+
+  // tell all workers to quit
+  for (int32_t wrkr_id = 0; wrkr_id < num_workers_; wrkr_id++) {
+    RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(TensorRow()));
+  }
+  return Status::OK();
+}
+
+Status BuildVocabOp::CollectorThread() {
+  TaskManager::FindMe()->Post();
+  int32_t num_quited_worker = 0;
+  std::unique_ptr<std::unordered_map<std::string, int64_t>> wrkr_map;
+  while (num_quited_worker != num_workers_) {
+    RETURN_IF_NOT_OK(collector_queue_->PopFront(&wrkr_map));
+    RETURN_UNEXPECTED_IF_NULL(wrkr_map);
+    if (!wrkr_map->empty()) {
+      for (const auto &wd : *wrkr_map) word_cnt_[wd.first] += wd.second;
+    } else {
+      ++num_quited_worker;
+    }
+  }  // all frequencies are obtained
+  CHECK_FAIL_RETURN_UNEXPECTED(!word_cnt_.empty(), "word_cnt is empty");
+  std::vector<std::string> words;
+  // make sure enough is reserved
+  words.reserve(word_cnt_.size());
+
+  for (auto it = word_cnt_.begin(); it != word_cnt_.end();) {
+    if (it->second >= freq_range_.first && it->second <= freq_range_.second) {
+      words.push_back(it->first);
+      it++;
+    } else {
+      it = word_cnt_.erase(it);
+    }
+  }
+  int64_t num_words = std::min(static_cast<int64_t>(words.size()), top_k_);
+  // this would take the top-k most frequent words
+  std::partial_sort(words.begin(), words.begin() + num_words, words.end(),
+                    [this](const std::string &w1, const std::string &w2) {
+                      int64_t f1 = word_cnt_[w1], f2 = word_cnt_[w2];
+                      return f1 == f2 ? w1 < w2 : f1 > f2;
+                    });
+  for (int64_t i = 0; i < num_words; i++) {
+    vocab_->append_word(words[i]);
+  }
+  RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
+  RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
+  return Status::OK();
+}
+
+Status BuildVocabOp::Builder::Build(std::shared_ptr<BuildVocabOp> *op) {
+  CHECK_FAIL_RETURN_UNEXPECTED(builder_num_workers_ > 0, "builder num_workers need to be greater than 0");
+  CHECK_FAIL_RETURN_UNEXPECTED(builder_top_k_ > 0, "top_k needs to be positive number");
+  CHECK_FAIL_RETURN_UNEXPECTED(builder_max_freq_ >= builder_min_freq_ && builder_min_freq_ >= 0,
+                               "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)");
+  (*op) = std::make_shared<BuildVocabOp>(builder_vocab_, builder_col_names_,
+                                         std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_,
+                                         builder_num_workers_, builder_connector_size_);
+  return Status::OK();
+}
+
+BuildVocabOp::Builder::Builder()
+    : builder_top_k_(std::numeric_limits<int64_t>::max()),
+      builder_min_freq_(0),
+      builder_max_freq_(std::numeric_limits<int64_t>::max()) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  builder_num_workers_ = cfg->num_parallel_workers();
+  builder_connector_size_ = cfg->op_connector_size();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h b/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h
new file mode 100644
index 0000000000..2ebc5a49c8
--- /dev/null
+++ b/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h
@@ -0,0 +1,153 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_
+#define DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "dataset/core/tensor.h"
+#include "dataset/engine/dataset_iterator.h"
+#include "dataset/engine/datasetops/parallel_op.h"
+#include "dataset/text/vocab.h"
+#include "dataset/util/queue.h"
+#include "dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class BuildVocabOp : public ParallelOp {
+ public:
+  class Builder {
+   public:
+    Builder();
+
+    // Destructor.
+    ~Builder() = default;
+
+    // Setter method
+    // @param int32_t size
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t size) {
+      builder_connector_size_ = size;
+      return *this;
+    }
+
+    // Setter method
+    // @param int32_t num_workers
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetNumWorkers(int32_t num_workers) {
+      builder_num_workers_ = num_workers;
+      return *this;
+    }
+
+    // Setter method
+    // @param int64_t top_k
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetTopK(int64_t top_k) {
+      builder_top_k_ = top_k;
+      return *this;
+    }
+
+    // Setter method
+    // @param int64_t min_freq
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetMinFreq(int64_t min_freq) {
+      builder_min_freq_ = min_freq;
+      return *this;
+    }
+
+    // Setter method
+    // @param int64_t max_freq
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetMaxFreq(int64_t max_freq) {
+      builder_max_freq_ = max_freq;
+      return *this;
+    }
+
+    // set columns names
+    // @param const std::vector<std::string> & col_names - name of columns to get words
+    // @return Builder & reference to builder class object
+    Builder &SetColumnNames(const std::vector<std::string> &col_names) {
+      builder_col_names_ = col_names;
+      return *this;
+    }
+
+    // set vocab object
+    Builder &SetVocab(std::shared_ptr<Vocab> vocab) {
+      builder_vocab_ = vocab;
+      return *this;
+    }
+
+    // The builder "build" method creates the final object.
+    // @param std::shared_ptr<BuildVocabOp> *op - DatasetOp
+    // @return - The error code return
+    Status Build(std::shared_ptr<BuildVocabOp> *op);
+
+   private:
+    int32_t builder_num_workers_;
+    int32_t builder_connector_size_;
+    int64_t builder_min_freq_;
+    int64_t builder_max_freq_;
+    std::vector<std::string> builder_col_names_;
+    std::shared_ptr<Vocab> builder_vocab_;
+    int64_t builder_top_k_;
+  };
+
+  BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names,
+               std::pair<int64_t, int64_t> freq_range, int64_t top_k, int32_t num_workers,
+               int32_t op_connector_size);
+
+  Status WorkerEntry(int32_t worker_id) override;
+
+  // collect the work product from each worker
+  Status CollectorThread();
+
+  Status EofReceived(int32_t) override { return Status::OK(); }
+
+  Status EoeReceived(int32_t) override { return Status::OK(); }
+
+  Status operator()() override;
+
+  // Getter
+  // @return the number of producer threads that push to the output Connector
+  int32_t num_producers() const override { return 1; }
+
+  // Getter
+  // @return the number of threads consuming from the previous Connector
+  int32_t num_consumers() const override { return 1; }
+
+  Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); }
+
+ private:
+  const int32_t interval_;
+  std::shared_ptr<Vocab> vocab_;
+  std::vector<std::string> col_names_;
+  std::vector<int32_t> col_ids_;
+  // pair = {min_f, max_f}
+  // the builder guarantees that 0 <= min_f <= max_f (both inclusive)
+  std::pair<int64_t, int64_t> freq_range_;
+
+  int64_t top_k_;  // top_k_ == int64_max means every word is taken
+  std::unique_ptr<ChildIterator> child_iterator_;        // child iterator for fetching TensorRows 1 by 1
+  std::unique_ptr<Queue<TensorRow>> distributor_queue_;  // master thread assigns each worker TensorRow via this
+  std::unique_ptr<Queue<std::unique_ptr<std::unordered_map<std::string, int64_t>>>> collector_queue_;
+  std::unordered_map<std::string, int64_t> word_cnt_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
index bd7da566b6..298edc8347 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
@@ -364,9 +364,7 @@ Status ImageFolderOp::startAsyncWalk() {
 }
 
 Status ImageFolderOp::LaunchThreadsAndInitOp() {
-  if (tree_ == nullptr) {
-    RETURN_STATUS_UNEXPECTED("tree_ not set");
-  }
+  RETURN_UNEXPECTED_IF_NULL(tree_);
   // Registers QueueList and individual Queues for interrupt services
   RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
   RETURN_IF_NOT_OK(folder_name_queue_->Register(tree_->AllTasks()));
diff --git a/mindspore/ccsrc/dataset/text/vocab.cc b/mindspore/ccsrc/dataset/text/vocab.cc
index 893336c62a..5c1892a0f9 100644
--- a/mindspore/ccsrc/dataset/text/vocab.cc
+++ b/mindspore/ccsrc/dataset/text/vocab.cc
@@ -21,19 +21,23 @@
 namespace mindspore {
 namespace dataset {
-Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) {
-  word2id_ = std::move(word2id);
-  id2word_.resize(word2id_.size());
-  for (auto p : word2id_) {
-    id2word_[p.second - kSpecialTokens::num_tokens] = p.first;
-  }
-}
+Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) { word2id_ = std::move(word2id); }
 
 WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const {
   auto itr = word2id_.find(word);
   return itr == word2id_.end() ? default_id : itr->second;
 }
-WordType Vocab::Lookup(WordIdType id) const {
+
+WordType Vocab::Lookup(WordIdType id) {
+  // reverse lookup is usually only needed once training is done, so id2word_ is built lazily here;
+  // hence the worst case of rebuilding it while words keep being appended is unlikely to happen
+  if (id2word_.size() != word2id_.size() && (id - kSpecialTokens::num_tokens >= id2word_.size())) {
+    id2word_.clear();
+    id2word_.resize(word2id_.size());
+    for (auto p : word2id_) {
+      id2word_[p.second - kSpecialTokens::num_tokens] = p.first;
+    }
+  }
   if (id < kSpecialTokens::num_tokens) {
     return reserved_token_str_[id];
   } else if (id - kSpecialTokens::num_tokens >= id2word_.size()) {
@@ -97,5 +101,11 @@ Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr *voc
   return Status::OK();
 }
 const std::vector<WordType> Vocab::reserved_token_str_ = {"<pad>", "<unk>"};
+
+void Vocab::append_word(const std::string &word) {
+  if (word2id_.find(word) == word2id_.end()) {
+    word2id_[word] = word2id_.size() + kSpecialTokens::num_tokens;
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/text/vocab.h b/mindspore/ccsrc/dataset/text/vocab.h
index 03eaf0f149..29557cdecb 100644
--- a/mindspore/ccsrc/dataset/text/vocab.h
+++ b/mindspore/ccsrc/dataset/text/vocab.h
@@ -65,12 +65,19 @@ class Vocab {
   // reverse lookup, lookup the word based on its id
   // @param WordIdType id - word id to lookup to
   // @return WordType the word
-  WordType Lookup(WordIdType id) const;
+  WordType Lookup(WordIdType id);
 
   // constructor, shouldn't be called directly, can't be private due to std::make_unique()
   // @param std::unordered_map<WordType, WordIdType> map - sanitized word2id map
   explicit Vocab(std::unordered_map<WordType, WordIdType> map);
 
+  Vocab() = default;
+
+  // add one word to the vocab and assign its index automatically
+  // @param const std::string &word - word to be added; it is skipped if it already exists
+  void append_word(const std::string &word);
+
+  // destructor
   ~Vocab() = default;
 
   // enum type that holds all special tokens, add more if needed
diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py
index adc78f48b8..044b639dfa 100644
--- a/mindspore/dataset/engine/__init__.py
+++ b/mindspore/dataset/engine/__init__.py
@@ -28,10 +28,10 @@ from .serializer_deserializer import serialize, deserialize, show, compare
 from .samplers import *
 from ..core.configuration import config, ConfigurationManager
 
-
 __all__ = ["config", "ConfigurationManager", "zip",
            "ImageFolderDatasetV2", "MnistDataset", "MindDataset",
            "GeneratorDataset", "TFRecordDataset", "ManifestDataset",
            "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
-           "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler",
+           "VOCDataset", "CocoDataset", "TextFileDataset", "BuildVocabDataset", "Schema",
+           "DistributedSampler", "PKSampler",
            "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"]
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index d667eccca5..270349984a 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -42,8 +42,8 @@ from .iterators import DictIterator, TupleIterator
 from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
     check_rename, check_numpyslicesdataset, \
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
-    check_tfrecorddataset, check_vocdataset,
check_cocodataset, check_celebadataset, check_minddataset,\ - check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat,\ + check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ + check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ check_split from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist @@ -824,6 +824,29 @@ class Dataset: return ProjectDataset(self, columns) + def build_vocab(self, vocab, columns, freq_range, top_k): + """ + Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab + which contains top_k most frequent words (if top_k is specified) + This function is not meant to be called directly by user. To build vocab, please use the function + text.Vocab.from_dataset() + + Args: + vocab(Vocab): vocab object + columns(str or list, optional): column names to get words from. It can be a list of column names. + (Default is None where all columns will be used. If any column isn't string type, will return error) + freq_range(tuple, optional): A tuple of integers (min_frequency, max_frequency). Words within the frequency + range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency + can be None, which corresponds to 0/total_words separately (default is None, all words are included) + top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are + taken. top_k is taken after freq_range. If not enough top_k, all words will be taken. (default is None + all words are included) + + Returns: + BuildVocabDataset + """ + return BuildVocabDataset(self, vocab, columns, freq_range, top_k) + def apply(self, apply_func): """ Apply a function in this dataset. @@ -1483,6 +1506,7 @@ class BatchDataset(DatasetOp): for input_dataset in dataset.input: BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) + class BatchInfo(CBatchInfo): """ The information object associates with the current batch of tensors. @@ -1506,6 +1530,7 @@ class BatchInfo(CBatchInfo): """ return + class BlockReleasePair: """ The blocking condition class used by SyncWaitDataset. @@ -1514,6 +1539,7 @@ class BlockReleasePair: init_release_rows (int): Number of lines to allow through the pipeline. callback (function): The callback funciton that will be called when release is called. """ + def __init__(self, init_release_rows, callback=None): if isinstance(init_release_rows, int) and init_release_rows <= 0: raise ValueError("release_rows need to be greater than 0.") @@ -1696,6 +1722,7 @@ class _PythonCallable: """ Internal python function wrapper for multiprocessing pyfunc. """ + def __init__(self, py_callable, idx, pool=None): # Original python callable from user. self.py_callable = py_callable @@ -2593,7 +2620,7 @@ class MindDataset(SourceDataset): if sampler is not None: if isinstance(sampler, samplers.SubsetRandomSampler) is False and \ - isinstance(sampler, samplers.PKSampler) is False: + isinstance(sampler, samplers.PKSampler) is False: raise ValueError("the sampler is not supported yet.") # sampler exclusive @@ -2859,6 +2886,7 @@ class _GeneratorWorker(multiprocessing.Process): """ Worker process for multiprocess Generator. 
""" + def __init__(self, dataset, eoe): self.idx_queue = multiprocessing.Queue(16) self.res_queue = multiprocessing.Queue(16) @@ -3686,6 +3714,7 @@ class RandomDataset(SourceDataset): def is_sharded(self): return False + class Schema: """ Class to represent a schema of dataset. @@ -4383,6 +4412,7 @@ class _NumpySlicesDataset: """ Mainly for dealing with several kinds of format of python data, and return one row each time. """ + def __init__(self, data, column_list=None): self.column_list = None # Convert dict data into tuple @@ -4525,6 +4555,7 @@ class NumpySlicesDataset(GeneratorDataset): >>> df = pd.read_csv("file.csv") >>> dataset4 = ds.NumpySlicesDataset(dict(df), shuffle=False) """ + @check_numpyslicesdataset def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None): @@ -4532,3 +4563,61 @@ class NumpySlicesDataset(GeneratorDataset): super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples, num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, num_shards=num_shards, shard_id=shard_id) + + +class BuildVocabDataset(DatasetOp): + """ + Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab + which contains top_k most frequent words (if top_k is specified) + This function is not meant to be called directly by user. To build vocab, please use the function + text.Vocab.from_dataset() + + Args: + vocab(Vocab): vocab object + columns(str or list, optional): column names to get words from. It can be a list of column names. + (Default is None where all columns will be used. If any column isn't string type, will return error) + freq_range(tuple, optional): A tuple of integers (min_frequency, max_frequency). Words within the frequency + range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency + can be None, which corresponds to 0/total_words separately (default is None, all words are included) + top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are + taken. top_k is taken after freq_range. If not enough top_k, all words will be taken. 
(default is None + all words are included) + + Returns: + BuildVocabDataset + """ + + def __init__(self, input_dataset, vocab, columns, freq_range, top_k, prefetch_size=None): + super().__init__() + self.columns = columns + self.input.append(input_dataset) + self.prefetch_size = prefetch_size + self.vocab = vocab + self.freq_range = freq_range + self.top_k = top_k + input_dataset.output.append(self) + + def get_args(self): + args = super().get_args() + args["columns"] = self.columns + args["vocab"] = self.vocab + args["freq_range"] = self.freq_range + args["prefetch_size"] = self.prefetch_size + args["top_k"] = self.top_k + return args + + def __deepcopy__(self, memodict): + if id(self) in memodict: + return memodict[id(self)] + cls = self.__class__ + new_op = cls.__new__(cls) + memodict[id(self)] = new_op + new_op.input = copy.deepcopy(self.input, memodict) + new_op.columns = copy.deepcopy(self.columns, memodict) + new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) + new_op.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) + new_op.output = copy.deepcopy(self.output, memodict) + new_op.freq_range = copy.deepcopy(self.freq_range, memodict) + new_op.top_k = copy.deepcopy(self.top_k, memodict) + new_op.vocab = self.vocab + return new_op diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index ad8725b302..d621f76256 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -177,6 +177,8 @@ class Iterator: op_type = OpName.RANDOMDATA elif isinstance(dataset, de.TextFileDataset): op_type = OpName.TEXTFILE + elif isinstance(dataset, de.BuildVocabDataset): + op_type = OpName.BUILDVOCAB else: raise ValueError("Unsupported DatasetOp") diff --git a/mindspore/dataset/text/utils.py b/mindspore/dataset/text/utils.py index 6b77edec9b..b590937d7d 100644 --- a/mindspore/dataset/text/utils.py +++ b/mindspore/dataset/text/utils.py @@ -16,20 +16,43 @@ Some basic function for nlp """ from enum import IntEnum +import copy import numpy as np import mindspore._c_dataengine as cde -from .validators import check_from_file, check_from_list, check_from_dict +from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset class Vocab(cde.Vocab): """ Vocab object that is used for lookup word - Args: """ - def __init__(self): - pass + @classmethod + @check_from_dataset + def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None): + """ + Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab + which contains top_k most frequent words (if top_k is specified) + Args: + dataset(Dataset): dataset to build vocab from. + columns(str or list, optional): column names to get words from. It can be a list of column names. + (Default is None where all columns will be used. If any column isn't string type, will return error) + freq_range(tuple, optional): A tuple of integers (min_frequency, max_frequency). Words within the frequency + range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency + can be None, which corresponds to 0/total_words separately (default is None, all words are included) + top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are + taken. top_k is taken after freq_range. If not enough top_k, all words will be taken. (default is None + all words are included) + return: + text.Vocab: vocab object built from dataset. 
+ """ + vocab = Vocab() + root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k) + for d in root.create_dict_iterator(): + if d is not None: + raise ValueError("from_dataset should receive data other than None") + return vocab @classmethod @check_from_list diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index da6f7dc2b1..25d0aaf241 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -186,6 +186,62 @@ def check_jieba_add_dict(method): return new_method +def check_from_dataset(method): + """A wrapper that wrap a parameter checker to the original function(crop operation).""" + + # def from_dataset(cls, dataset, columns, freq_range=None, top_k=None): + @wraps(method) + def new_method(self, *args, **kwargs): + dataset, columns, freq_range, top_k = (list(args) + 4 * [None])[:4] + if "dataset" in kwargs: + dataset = kwargs.get("dataset") + if "columns" in kwargs: + columns = kwargs.get("columns") + if "freq_range" in kwargs: + freq_range = kwargs.get("freq_range") + if "top_k" in kwargs: + top_k = kwargs.get("top_k") + + if columns is None: + columns = [] + + if not isinstance(columns, list): + columns = [columns] + + for column in columns: + if not isinstance(column, str): + raise ValueError("columns need to be a list of strings") + + if freq_range is None: + freq_range = (None, None) + + if not isinstance(freq_range, tuple) or len(freq_range) != 2: + raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None") + + for num in freq_range: + if num is not None and (not isinstance(num, int)): + raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None") + + if isinstance(freq_range[0], int) and isinstance(freq_range[1], int): + if freq_range[0] > freq_range[1] or freq_range[0] < 0: + raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)") + + if top_k is not None and (not isinstance(top_k, int)): + raise ValueError("top_k needs to be a positive integer") + + if isinstance(top_k, int) and top_k <= 0: + raise ValueError("top_k needs to be a positive integer") + + kwargs["dataset"] = dataset + kwargs["columns"] = columns + kwargs["freq_range"] = freq_range + kwargs["top_k"] = top_k + + return method(self, **kwargs) + + return new_method + + def check_ngram(method): """A wrapper that wrap a parameter checker to the original function(crop operation).""" diff --git a/tests/ut/python/dataset/test_from_dataset.py b/tests/ut/python/dataset/test_from_dataset.py new file mode 100644 index 0000000000..719f21fd7d --- /dev/null +++ b/tests/ut/python/dataset/test_from_dataset.py @@ -0,0 +1,112 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +""" +Testing from_dataset in mindspore.dataset +""" +import numpy as np +import mindspore.dataset as ds +import mindspore.dataset.text as text + + +def test_demo_basic_from_dataset(): + """ this is a tutorial on how from_dataset should be used in a normal use case""" + data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) + vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None) + data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) + res = [] + for d in data.create_dict_iterator(): + res.append(d["text"].item()) + assert res == [4, 5, 3, 6, 7, 2] + + +def test_demo_basic_from_dataset_with_tokenizer(): + """ this is a tutorial on how from_dataset should be used in a normal use case with tokenizer""" + data = ds.TextFileDataset("../data/dataset/testTokenizerData/1.txt", shuffle=False) + data = data.map(input_columns=["text"], operations=text.UnicodeCharTokenizer()) + vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None) + data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) + res = [] + for d in data.create_dict_iterator(): + res.append(list(d["text"])) + assert res == [[13, 3, 7, 14, 9, 17, 3, 2, 19, 9, 2, 11, 3, 4, 16, 4, 8, 6, 5], [21, 20, 10, 25, 23, 26], + [24, 22, 10, 12, 8, 6, 7, 4, 18, 15, 5], [2, 2]] + + +def test_from_dataset(): + """ test build vocab with generator dataset """ + + def gen_corpus(): + # key: word, value: number of occurrences, reason for using letters is so their order is apparent + corpus = {"Z": 4, "Y": 4, "X": 4, "W": 3, "U": 3, "V": 2, "T": 1} + for k, v in corpus.items(): + yield (np.array([k] * v, dtype='S'),) + + def test_config(freq_range, top_k): + corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"]) + vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k) + corpus_dataset = corpus_dataset.map(input_columns="text", operations=text.Lookup(vocab)) + res = [] + for d in corpus_dataset.create_dict_iterator(): + res.append(list(d["text"])) + return res + + # take words whose frequency is with in [3,4] order them alphabetically for words with the same frequency + test1_res = test_config(freq_range=(3, 4), top_k=4) + assert test1_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [5, 5, 5], [1, 1], [1]], str(test1_res) + + # test words with frequency range [2,inf], only the last word will be filtered out + test2_res = test_config((2, None), None) + assert test2_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [6, 6, 6], [5, 5, 5], [7, 7], [1]], str(test2_res) + + # test filter only by top_k + test3_res = test_config(None, 4) + assert test3_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [5, 5, 5], [1, 1], [1]], str(test3_res) + + # test filtering out the most frequent + test4_res = test_config((None, 3), 100) + assert test4_res == [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [3, 3, 3], [2, 2, 2], [4, 4], [5]], str(test4_res) + + # test top_k == 1 + test5_res = test_config(None, 1) + assert test5_res == [[1, 1, 1, 1], [1, 1, 1, 1], [2, 2, 2, 2], [1, 1, 1], [1, 1, 1], [1, 1], [1]], str(test5_res) + + # test min_frequency == max_frequency + test6_res = test_config((4, 4), None) + assert test6_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [1, 1, 1], [1, 1], [1]], str(test6_res) + + +def test_from_dataset_exceptions(): + """ test various exceptions during that are checked in validator """ + + def 
test_config(columns, freq_range, top_k, s):
+        try:
+            data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
+            vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k)
+            assert isinstance(vocab, text.Vocab)
+        except ValueError as e:
+            assert s in str(e), str(e)
+
+    test_config("text", (), 1, "freq_range needs to be either None or a tuple of 2 integers")
+    test_config("text", (2, 3), 1.2345, "top_k needs to be a positive integer")
+    test_config(23, (2, 3), 1.2345, "columns need to be a list of strings")
+    test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b")
+    test_config("text", (2, 3), 0, "top_k needs to be a positive integer")
+    test_config([123], (2, 3), 0, "columns need to be a list of strings")
+
+if __name__ == '__main__':
+    test_demo_basic_from_dataset()
+    test_from_dataset()
+    test_from_dataset_exceptions()
+    test_demo_basic_from_dataset_with_tokenizer()
diff --git a/tests/ut/python/dataset/test_ngram_op.py b/tests/ut/python/dataset/test_ngram_op.py
index 64a5106801..94d00674d7 100644
--- a/tests/ut/python/dataset/test_ngram_op.py
+++ b/tests/ut/python/dataset/test_ngram_op.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """
-Testing NgramOP in DE
+Testing Ngram in mindspore.dataset
 """
 import mindspore.dataset as ds
 import mindspore.dataset.text as nlp
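
For reviewers, a minimal end-to-end sketch of how the new API is intended to be used, mirroring test_demo_basic_from_dataset above; the corpus path and the freq_range/top_k values are illustrative and not part of the patch:

    import mindspore.dataset as ds
    import mindspore.dataset.text as text

    # build a vocab from the "text" column: keep words whose frequency is within [2, 10],
    # then keep only the 5000 most frequent of those
    data = ds.TextFileDataset("/path/to/corpus.txt", shuffle=False)
    vocab = text.Vocab.from_dataset(data, columns=["text"], freq_range=(2, 10), top_k=5000)

    # the resulting vocab can then be used to map words to ids with the existing Lookup op
    data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
    for row in data.create_dict_iterator():
        print(row["text"])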
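
The word-selection rule implemented in BuildVocabOp::CollectorThread (filter by freq_range first, then take the top_k most frequent words, breaking frequency ties alphabetically) is equivalent to the small Python sketch below; select_vocab_words is a hypothetical helper written only to document the semantics:

    def select_vocab_words(word_cnt, freq_range, top_k):
        # word_cnt maps word -> frequency, as accumulated by the collector thread
        min_f, max_f = freq_range
        # keep only words whose frequency falls inside [min_f, max_f] (both inclusive)
        words = [w for w, c in word_cnt.items() if min_f <= c <= max_f]
        # most frequent first; ties are broken alphabetically, matching the C++ partial_sort comparator
        words.sort(key=lambda w: (-word_cnt[w], w))
        return words[:top_k]

    # e.g. select_vocab_words({"Z": 4, "Y": 4, "W": 3, "T": 1}, (2, 4), 3) returns ['Y', 'Z', 'W']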