fix compile error; more tests; address CI complaints; fix CI; address review comments (tag: v0.5.0-beta)
@@ -71,7 +71,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
    {kCifar100, &DEPipeline::ParseCifar100Op},
    {kCelebA, &DEPipeline::ParseCelebAOp},
    {kRandomData, &DEPipeline::ParseRandomDataOp},
-   {kTextFile, &DEPipeline::ParseTextFileOp}};
+   {kTextFile, &DEPipeline::ParseTextFileOp},
+   {kBuildVocab, &DEPipeline::ParseBuildVocabOp}};

DEPipeline::DEPipeline() : iterator_(nullptr) {
  try {
@@ -1235,5 +1236,36 @@ Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
  }
  return Status::OK();
}

Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  std::shared_ptr<BuildVocabOp::Builder> builder = std::make_shared<BuildVocabOp::Builder>();
  for (auto arg : args) {
    std::string key = py::str(arg.first);
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "freq_range") {
        py::tuple tp = py::reinterpret_borrow<py::tuple>(value);
        if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow<py::int_>(tp[0]));
        if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow<py::int_>(tp[1]));
      }
      if (key == "top_k") {
        builder->SetTopK(py::reinterpret_borrow<py::int_>(value));
      }
      if (key == "columns") {
        (void)builder->SetColumnNames(ToStringVector(value));
      }
      if (key == "vocab") {
        (void)builder->SetVocab(value.cast<std::shared_ptr<Vocab>>());
      }
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
      }
    }
  }
  std::shared_ptr<BuildVocabOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *ptr = op;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore
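For orientation, the keys consumed by this parser mirror the args dict that the Python layer's BuildVocabDataset.get_args() (further down in this change) produces. A minimal sketch of such a dict, with purely illustrative values:

# Illustrative only: shape of the args dict that ParseBuildVocabOp expects.
vocab = None  # placeholder; in practice this is the shared Vocab object created by Vocab.from_dataset()
args = {
    "vocab": vocab,
    "columns": ["text"],      # string columns to draw words from
    "freq_range": (1, None),  # (min_frequency, max_frequency); either end may be None
    "top_k": 1000,            # keep only the 1000 most frequent words
    "num_parallel_workers": 4,
}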
@@ -63,7 +63,8 @@ enum OpName {
  kCifar100,
  kCelebA,
  kRandomData,
- kTextFile
+ kTextFile,
+ kBuildVocab
};

// The C++ binder class that we expose to the python script.
@@ -163,6 +164,8 @@ class DEPipeline {
  Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

+ Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

 private:
  // Execution tree that links the dataset operators.
  std::shared_ptr<ExecutionTree> tree_;
@@ -514,6 +514,7 @@ void bindInfoObjects(py::module *m) {
void bindVocabObjects(py::module *m) {
  (void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
+   .def(py::init<>())
    .def_static("from_list",
                [](const py::list &words) {
                  std::shared_ptr<Vocab> v;
@@ -624,6 +625,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
    .value("CIFAR10", OpName::kCifar10)
    .value("CIFAR100", OpName::kCifar100)
    .value("RANDOMDATA", OpName::kRandomData)
+   .value("BUILDVOCAB", OpName::kBuildVocab)
    .value("CELEBA", OpName::kCelebA)
    .value("TEXTFILE", OpName::kTextFile);
@@ -27,6 +27,7 @@
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/barrier_op.h"
#include "dataset/engine/datasetops/batch_op.h"
+#include "dataset/engine/datasetops/build_vocab_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
#include "dataset/engine/datasetops/map_op.h"
@@ -19,5 +19,6 @@ add_library(engine-datasetops OBJECT
    zip_op.cc
    concat_op.cc
    filter_op.cc
+   build_vocab_op.cc
    )
@@ -0,0 +1,179 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/datasetops/build_vocab_op.h"

#include <algorithm>
#include <limits>
#include <string>
#include <unordered_map>
#include <utility>

#include "dataset/core/config_manager.h"

namespace mindspore {
namespace dataset {
BuildVocabOp::BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names,
                           std::pair<int64_t, int64_t> freq_r, int64_t top_k, int32_t num_workers,
                           int32_t op_conn_size)
    : ParallelOp(num_workers, op_conn_size),
      interval_(op_conn_size * num_workers),
      vocab_(vocab),
      col_names_(col_names),
      freq_range_(freq_r),
      top_k_(top_k) {
  // init two queues for thread sync
  distributor_queue_ = std::make_unique<Queue<TensorRow>>(num_workers * op_conn_size);
  collector_queue_ =
      std::make_unique<Queue<std::unique_ptr<std::unordered_map<std::string, int64_t>>>>(num_workers * op_conn_size);
}
Status BuildVocabOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  TensorRow new_row;
  RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row));
  std::unique_ptr<std::unordered_map<std::string, int64_t>> wrkr_map =
      std::make_unique<std::unordered_map<std::string, int64_t>>();
  int32_t row_cnt = 0;
  while (!new_row.empty()) {
    for (int32_t col : col_ids_) {
      CHECK_FAIL_RETURN_UNEXPECTED(!new_row[col]->type().IsNumeric(), "from_dataset only works on string columns");
      for (auto itr = new_row[col]->begin<std::string_view>(); itr != new_row[col]->end<std::string_view>(); itr++) {
        (*wrkr_map)[std::string(*itr)] += 1;
      }
    }
    row_cnt++;  // row is processed by this point
    if ((row_cnt % interval_ == 0) && ((row_cnt / interval_) % num_workers_ == worker_id) && (!wrkr_map->empty())) {
      RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map)));
      wrkr_map = std::make_unique<std::unordered_map<std::string, int64_t>>();
    }
    RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row));
  }
  // clean up
  if (!wrkr_map->empty()) {
    RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map)));
  }
  // empty map as quit signal
  RETURN_IF_NOT_OK(collector_queue_->Add(std::make_unique<std::unordered_map<std::string, int64_t>>()));
  return Status::OK();
}
Status BuildVocabOp::operator()() {
  RETURN_UNEXPECTED_IF_NULL(tree_);
  RETURN_IF_NOT_OK(distributor_queue_->Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(collector_queue_->Register(tree_->AllTasks()));
  // launch worker threads and collector thread
  RETURN_IF_NOT_OK(
      tree_->LaunchWorkers(num_workers_, std::bind(&BuildVocabOp::WorkerEntry, this, std::placeholders::_1)));
  RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("collector", std::bind(&BuildVocabOp::CollectorThread, this)));
  TaskManager::FindMe()->Post();
  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
  TensorRow new_row;
  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  RETURN_IF_NOT_OK(AssignColMapFromChild());
  if (!col_names_.empty()) {
    col_ids_.reserve(col_names_.size());
    for (std::string col : col_names_) {
      auto itr = column_name_id_map_.find(col);
      CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(), col + " column doesn't exist");
      col_ids_.push_back(itr->second);
    }
  } else {
    col_ids_.reserve(column_name_id_map_.size());
    for (const auto &p : column_name_id_map_) {
      col_ids_.push_back(p.second);
    }
  }
  bool eoe_warning = false;  // error out if more than one EOE is received (i.e. a repeat after from_dataset)
  while (child_iterator_->eof_handled() == false) {
    while (new_row.empty() == false) {
      RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(new_row));
      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
    }
    CHECK_FAIL_RETURN_UNEXPECTED(!eoe_warning, "no op should be after from_dataset (repeat detected)");
    eoe_warning = true;
  }
  // tell all workers to quit
  for (int32_t wrkr_id = 0; wrkr_id < num_workers_; wrkr_id++) {
    RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(TensorRow()));
  }
  return Status::OK();
}
Status BuildVocabOp::CollectorThread() {
  TaskManager::FindMe()->Post();
  int32_t num_quited_worker = 0;
  std::unique_ptr<std::unordered_map<std::string, int64_t>> wrkr_map;
  while (num_quited_worker != num_workers_) {
    RETURN_IF_NOT_OK(collector_queue_->PopFront(&wrkr_map));
    RETURN_UNEXPECTED_IF_NULL(wrkr_map);
    if (!wrkr_map->empty()) {
      for (const auto &wd : *wrkr_map) word_cnt_[wd.first] += wd.second;
    } else {
      ++num_quited_worker;
    }
  }  // all frequencies are obtained
  CHECK_FAIL_RETURN_UNEXPECTED(!word_cnt_.empty(), "word_cnt is empty");
  std::vector<std::string> words;
  // reserve enough space for all counted words
  words.reserve(word_cnt_.size());
  for (auto it = word_cnt_.begin(); it != word_cnt_.end();) {
    if (it->second >= freq_range_.first && it->second <= freq_range_.second) {
      words.push_back(it->first);
      it++;
    } else {
      it = word_cnt_.erase(it);
    }
  }
  int64_t num_words = std::min(static_cast<int64_t>(words.size()), top_k_);
  // partial_sort puts the top-k most frequent words first; frequency ties are broken alphabetically
  std::partial_sort(words.begin(), words.begin() + num_words, words.end(),
                    [this](const std::string &w1, const std::string &w2) {
                      int64_t f1 = word_cnt_[w1], f2 = word_cnt_[w2];
                      return f1 == f2 ? w1 < w2 : f1 > f2;
                    });
  for (int64_t i = 0; i < num_words; i++) {
    vocab_->append_word(words[i]);
  }
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
  return Status::OK();
}
Status BuildVocabOp::Builder::Build(std::shared_ptr<BuildVocabOp> *op) {
  CHECK_FAIL_RETURN_UNEXPECTED(builder_num_workers_ > 0, "builder num_workers need to be greater than 0");
  CHECK_FAIL_RETURN_UNEXPECTED(builder_top_k_ > 0, "top_k needs to be positive number");
  CHECK_FAIL_RETURN_UNEXPECTED(builder_max_freq_ >= builder_min_freq_ && builder_min_freq_ >= 0,
                               "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)");
  (*op) = std::make_shared<BuildVocabOp>(builder_vocab_, builder_col_names_,
                                         std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_,
                                         builder_num_workers_, builder_connector_size_);
  return Status::OK();
}

BuildVocabOp::Builder::Builder()
    : builder_top_k_(std::numeric_limits<int64_t>::max()),
      builder_min_freq_(0),
      builder_max_freq_(std::numeric_limits<int64_t>::max()) {
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  builder_num_workers_ = cfg->num_parallel_workers();
  builder_connector_size_ = cfg->op_connector_size();
}
} // namespace dataset
} // namespace mindspore
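As a rough intuition for the selection logic in CollectorThread (frequency filter first, then top_k, with frequency ties broken alphabetically), a small Python analogue follows; it is only a sketch of the same idea, not part of the patch:

from collections import Counter

def select_vocab_words(tokens, freq_range=(0, float("inf")), top_k=None):
    """Keep words whose count falls inside freq_range, then take the top_k most
    frequent ones, breaking frequency ties alphabetically (like the partial_sort comparator)."""
    counts = Counter(tokens)
    kept = [(word, cnt) for word, cnt in counts.items() if freq_range[0] <= cnt <= freq_range[1]]
    kept.sort(key=lambda wc: (-wc[1], wc[0]))  # descending frequency, then ascending word
    if top_k is not None:
        kept = kept[:top_k]
    return [word for word, _ in kept]

# select_vocab_words(["b", "a", "a", "c"], freq_range=(1, 2), top_k=2) -> ["a", "b"]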
@@ -0,0 +1,153 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_
#define DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_

#include <vector>
#include <memory>
#include <unordered_map>
#include <string>
#include <utility>

#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/text/vocab.h"
#include "dataset/util/queue.h"
#include "dataset/util/status.h"

namespace mindspore {
namespace dataset {
class BuildVocabOp : public ParallelOp {
 public:
  class Builder {
   public:
    Builder();

    // Destructor.
    ~Builder() = default;

    // Setter method
    // @param int32_t size
    // @return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t size) {
      builder_connector_size_ = size;
      return *this;
    }

    // Setter method
    // @param int32_t num_workers
    // @return Builder setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Setter method
    // @param int64_t top_k
    // @return Builder setter method returns reference to the builder.
    Builder &SetTopK(int64_t top_k) {
      builder_top_k_ = top_k;
      return *this;
    }

    // Setter method
    // @param int64_t min_freq
    // @return Builder setter method returns reference to the builder.
    Builder &SetMinFreq(int64_t min_freq) {
      builder_min_freq_ = min_freq;
      return *this;
    }

    // Setter method
    // @param int64_t max_freq
    // @return Builder setter method returns reference to the builder.
    Builder &SetMaxFreq(int64_t max_freq) {
      builder_max_freq_ = max_freq;
      return *this;
    }

    // set columns names
    // @param const std::vector<std::string> &col_names - name of columns to get words from
    // @return Builder & reference to builder class object
    Builder &SetColumnNames(const std::vector<std::string> &col_names) {
      builder_col_names_ = col_names;
      return *this;
    }

    // set vocab object
    Builder &SetVocab(std::shared_ptr<Vocab> vocab) {
      builder_vocab_ = vocab;
      return *this;
    }

    // The builder "build" method creates the final object.
    // @param std::shared_ptr<BuildVocabOp> *op - DatasetOp
    // @return - The error code return
    Status Build(std::shared_ptr<BuildVocabOp> *op);

   private:
    int32_t builder_num_workers_;
    int32_t builder_connector_size_;
    int64_t builder_min_freq_;
    int64_t builder_max_freq_;
    std::vector<std::string> builder_col_names_;
    std::shared_ptr<Vocab> builder_vocab_;
    int64_t builder_top_k_;
  };
  BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names,
               std::pair<int64_t, int64_t> freq_range, int64_t top_k, int32_t num_workers, int32_t op_connector_size);

  Status WorkerEntry(int32_t worker_id) override;

  // collect the work product from each worker
  Status CollectorThread();

  Status EofReceived(int32_t) override { return Status::OK(); }

  Status EoeReceived(int32_t) override { return Status::OK(); }

  Status operator()() override;

  // Getter
  // @return the number of workers
  int32_t num_producers() const override { return 1; }

  // Getter
  // @return the number of threads consuming from the previous Connector
  int32_t num_consumers() const override { return 1; }

  Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); }

 private:
  const int32_t interval_;
  std::shared_ptr<Vocab> vocab_;
  std::vector<std::string> col_names_;
  std::vector<int32_t> col_ids_;
  // pair = {min_f, max_f}; the builder guarantees 0 <= min_f <= max_f
  std::pair<int64_t, int64_t> freq_range_;
  int64_t top_k_;  // top_k_ == int64_t max means "take every word"
  std::unique_ptr<ChildIterator> child_iterator_;         // child iterator for fetching TensorRows 1 by 1
  std::unique_ptr<Queue<TensorRow>> distributor_queue_;   // master thread assigns each worker TensorRow via this
  std::unique_ptr<Queue<std::unique_ptr<std::unordered_map<std::string, int64_t>>>> collector_queue_;
  std::unordered_map<std::string, int64_t> word_cnt_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_
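The members above encode a simple fan-out/fan-in design: the master thread pushes rows into distributor_queue_, each worker counts words locally and hands partial maps to collector_queue_, and a single collector merges them. A rough Python analogue of that flow (hypothetical names; the periodic flushing and interrupt handling are omitted):

import queue
import threading
from collections import Counter

def worker(rows_q, maps_q):
    local = Counter()
    while True:
        row = rows_q.get()
        if row is None:        # an empty row acts as the quit signal
            break
        local.update(row)      # count the words carried by this row
    maps_q.put(local)          # hand the partial counts to the collector

def count_words(rows, num_workers=2):
    rows_q, maps_q = queue.Queue(), queue.Queue()
    workers = [threading.Thread(target=worker, args=(rows_q, maps_q)) for _ in range(num_workers)]
    for t in workers:
        t.start()
    for row in rows:           # master thread distributes the rows
        rows_q.put(row)
    for _ in workers:
        rows_q.put(None)       # one quit signal per worker
    total = Counter()
    for _ in workers:
        total += maps_q.get()  # collector merges the partial maps
    for t in workers:
        t.join()
    return total

# count_words([["a", "b"], ["a"]]) -> Counter({'a': 2, 'b': 1})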
@@ -364,9 +364,7 @@ Status ImageFolderOp::startAsyncWalk() {
}

Status ImageFolderOp::LaunchThreadsAndInitOp() {
-  if (tree_ == nullptr) {
-    RETURN_STATUS_UNEXPECTED("tree_ not set");
-  }
+  RETURN_UNEXPECTED_IF_NULL(tree_);
  // Registers QueueList and individual Queues for interrupt services
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(folder_name_queue_->Register(tree_->AllTasks()));
@@ -21,19 +21,23 @@
namespace mindspore {
namespace dataset {
-Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) {
-  word2id_ = std::move(word2id);
-  id2word_.resize(word2id_.size());
-  for (auto p : word2id_) {
-    id2word_[p.second - kSpecialTokens::num_tokens] = p.first;
-  }
-}
+Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) { word2id_ = std::move(word2id); }

WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const {
  auto itr = word2id_.find(word);
  return itr == word2id_.end() ? default_id : itr->second;
}

-WordType Vocab::Lookup(WordIdType id) const {
+WordType Vocab::Lookup(WordIdType id) {
+  // the reverse index is rebuilt lazily: reverse lookup is most likely only needed once training is done,
+  // so the worst case of repeatedly rebuilding while looking up is unlikely to happen
+  if (id2word_.size() != word2id_.size() && (id - kSpecialTokens::num_tokens >= id2word_.size())) {
+    id2word_.clear();
+    id2word_.resize(word2id_.size());
+    for (auto p : word2id_) {
+      id2word_[p.second - kSpecialTokens::num_tokens] = p.first;
+    }
+  }
  if (id < kSpecialTokens::num_tokens) {
    return reserved_token_str_[id];
  } else if (id - kSpecialTokens::num_tokens >= id2word_.size()) {

@@ -97,5 +101,11 @@ Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *voc
  return Status::OK();
}

const std::vector<WordType> Vocab::reserved_token_str_ = {"<pad>", "<unk>"};

+void Vocab::append_word(const std::string &word) {
+  if (word2id_.find(word) == word2id_.end()) {
+    word2id_[word] = word2id_.size() + kSpecialTokens::num_tokens;
+  }
+}
} // namespace dataset
} // namespace mindspore
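The net effect of append_word() and the lazy reverse index above is a fixed id layout: the reserved tokens keep the first ids and appended words are numbered after them in insertion order. A tiny Python sketch of that scheme, assuming the two reserved tokens listed in reserved_token_str_:

# Assumed id layout: reserved tokens first, then appended words in insertion order.
reserved = ["<pad>", "<unk>"]   # reserved_token_str_ above
word2id = {}

def append_word(word):
    # mirrors Vocab::append_word: duplicates are skipped, ids continue after the reserved ones
    if word not in word2id:
        word2id[word] = len(word2id) + len(reserved)

def lookup_word(word_id):
    # reverse lookup; ids below len(reserved) map to the reserved strings
    if word_id < len(reserved):
        return reserved[word_id]
    id2word = {i: w for w, i in word2id.items()}
    return id2word.get(word_id, "<unk>")  # out-of-range ids fall back to <unk> in this sketch

for w in ["home", "is", "behind"]:
    append_word(w)
# lookup_word(2) -> "home", lookup_word(0) -> "<pad>"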
@@ -65,12 +65,19 @@ class Vocab {
  // reverse lookup, lookup the word based on its id
  // @param WordIdType id - word id to lookup
  // @return WordType the word
- WordType Lookup(WordIdType id) const;
+ WordType Lookup(WordIdType id);

  // constructor, shouldn't be called directly, can't be private due to std::make_unique()
  // @param std::unordered_map<WordType, WordIdType> map - sanitized word2id map
  explicit Vocab(std::unordered_map<WordType, WordIdType> map);

+ Vocab() = default;
+
+ // add one word to the vocab, incrementing its index automatically
+ // @param const std::string &word - word to be added; skipped if the word already exists
+ void append_word(const std::string &word);
+
  // destructor
  ~Vocab() = default;

  // enum type that holds all special tokens, add more if needed
@@ -28,10 +28,10 @@ from .serializer_deserializer import serialize, deserialize, show, compare
from .samplers import *
from ..core.configuration import config, ConfigurationManager

__all__ = ["config", "ConfigurationManager", "zip",
           "ImageFolderDatasetV2", "MnistDataset",
           "MindDataset", "GeneratorDataset", "TFRecordDataset",
           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
-          "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler",
+          "VOCDataset", "CocoDataset", "TextFileDataset", "BuildVocabDataset", "Schema",
+          "DistributedSampler", "PKSampler",
           "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"]
@@ -42,8 +42,8 @@ from .iterators import DictIterator, TupleIterator
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
    check_rename, check_numpyslicesdataset, \
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
-   check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset,\
-   check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat,\
+   check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
+   check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
    check_split
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -824,6 +824,29 @@
        return ProjectDataset(self, columns)

    def build_vocab(self, vocab, columns, freq_range, top_k):
        """
        Build a vocab from a dataset. This collects all the unique words in the dataset and returns a vocab
        containing the top_k most frequent words (if top_k is specified).
        This function is not meant to be called directly by the user. To build a vocab, please use
        text.Vocab.from_dataset().

        Args:
            vocab(Vocab): vocab object.
            columns(str or list, optional): column names to get words from. It can be a list of column names
                (default=None, where all columns are used; an error is raised if any selected column is not of
                string type).
            freq_range(tuple, optional): a tuple of integers (min_frequency, max_frequency). Words within the
                frequency range are kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency and
                max_frequency can be None, which corresponds to 0 and total_words respectively
                (default=None, all words are included).
            top_k(int, optional): top_k > 0. Number of words to build into the vocab. The top_k most frequent
                words are taken; top_k is applied after freq_range. If there are fewer than top_k words, all
                words are taken (default=None, all words are included).

        Returns:
            BuildVocabDataset
        """
        return BuildVocabDataset(self, vocab, columns, freq_range, top_k)

    def apply(self, apply_func):
        """
        Apply a function in this dataset.
@@ -1483,6 +1506,7 @@ class BatchDataset(DatasetOp):
        for input_dataset in dataset.input:
            BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)


class BatchInfo(CBatchInfo):
    """
    The information object associates with the current batch of tensors.

@@ -1506,6 +1530,7 @@ class BatchInfo(CBatchInfo):
        """
        return


class BlockReleasePair:
    """
    The blocking condition class used by SyncWaitDataset.

@@ -1514,6 +1539,7 @@ class BlockReleasePair:
        init_release_rows (int): Number of lines to allow through the pipeline.
        callback (function): The callback funciton that will be called when release is called.
    """

    def __init__(self, init_release_rows, callback=None):
        if isinstance(init_release_rows, int) and init_release_rows <= 0:
            raise ValueError("release_rows need to be greater than 0.")

@@ -1696,6 +1722,7 @@ class _PythonCallable:
    """
    Internal python function wrapper for multiprocessing pyfunc.
    """

    def __init__(self, py_callable, idx, pool=None):
        # Original python callable from user.
        self.py_callable = py_callable
@@ -2593,7 +2620,7 @@ class MindDataset(SourceDataset):
        if sampler is not None:
            if isinstance(sampler, samplers.SubsetRandomSampler) is False and \
-                isinstance(sampler, samplers.PKSampler) is False:
+                    isinstance(sampler, samplers.PKSampler) is False:
                raise ValueError("the sampler is not supported yet.")

        # sampler exclusive
@@ -2859,6 +2886,7 @@ class _GeneratorWorker(multiprocessing.Process):
    """
    Worker process for multiprocess Generator.
    """

    def __init__(self, dataset, eoe):
        self.idx_queue = multiprocessing.Queue(16)
        self.res_queue = multiprocessing.Queue(16)

@@ -3686,6 +3714,7 @@ class RandomDataset(SourceDataset):
    def is_sharded(self):
        return False


class Schema:
    """
    Class to represent a schema of dataset.

@@ -4383,6 +4412,7 @@ class _NumpySlicesDataset:
    """
    Mainly for dealing with several kinds of format of python data, and return one row each time.
    """

    def __init__(self, data, column_list=None):
        self.column_list = None
        # Convert dict data into tuple

@@ -4525,6 +4555,7 @@ class NumpySlicesDataset(GeneratorDataset):
        >>> df = pd.read_csv("file.csv")
        >>> dataset4 = ds.NumpySlicesDataset(dict(df), shuffle=False)
    """

    @check_numpyslicesdataset
    def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None,
                 sampler=None, num_shards=None, shard_id=None):
@@ -4532,3 +4563,61 @@ class NumpySlicesDataset(GeneratorDataset):
        super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
                         num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
                         num_shards=num_shards, shard_id=shard_id)


class BuildVocabDataset(DatasetOp):
    """
    Build a vocab from a dataset. This collects all the unique words in the dataset and returns a vocab
    containing the top_k most frequent words (if top_k is specified).
    This class is not meant to be instantiated directly by the user. To build a vocab, please use
    text.Vocab.from_dataset().

    Args:
        vocab(Vocab): vocab object.
        columns(str or list, optional): column names to get words from. It can be a list of column names
            (default=None, where all columns are used; an error is raised if any selected column is not of
            string type).
        freq_range(tuple, optional): a tuple of integers (min_frequency, max_frequency). Words within the
            frequency range are kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency and
            max_frequency can be None, which corresponds to 0 and total_words respectively
            (default=None, all words are included).
        top_k(int, optional): top_k > 0. Number of words to build into the vocab. The top_k most frequent
            words are taken; top_k is applied after freq_range. If there are fewer than top_k words, all
            words are taken (default=None, all words are included).

    Returns:
        BuildVocabDataset
    """

    def __init__(self, input_dataset, vocab, columns, freq_range, top_k, prefetch_size=None):
        super().__init__()
        self.columns = columns
        self.input.append(input_dataset)
        self.prefetch_size = prefetch_size
        self.vocab = vocab
        self.freq_range = freq_range
        self.top_k = top_k
        input_dataset.output.append(self)

    def get_args(self):
        args = super().get_args()
        args["columns"] = self.columns
        args["vocab"] = self.vocab
        args["freq_range"] = self.freq_range
        args["prefetch_size"] = self.prefetch_size
        args["top_k"] = self.top_k
        return args

    def __deepcopy__(self, memodict):
        if id(self) in memodict:
            return memodict[id(self)]
        cls = self.__class__
        new_op = cls.__new__(cls)
        memodict[id(self)] = new_op
        new_op.input = copy.deepcopy(self.input, memodict)
        new_op.columns = copy.deepcopy(self.columns, memodict)
        new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
        new_op.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
        new_op.output = copy.deepcopy(self.output, memodict)
        new_op.freq_range = copy.deepcopy(self.freq_range, memodict)
        new_op.top_k = copy.deepcopy(self.top_k, memodict)
        new_op.vocab = self.vocab
        return new_op
@@ -177,6 +177,8 @@ class Iterator:
            op_type = OpName.RANDOMDATA
        elif isinstance(dataset, de.TextFileDataset):
            op_type = OpName.TEXTFILE
+        elif isinstance(dataset, de.BuildVocabDataset):
+            op_type = OpName.BUILDVOCAB
        else:
            raise ValueError("Unsupported DatasetOp")
@@ -16,20 +16,43 @@ Some basic function for nlp
"""
from enum import IntEnum
+import copy

import numpy as np

import mindspore._c_dataengine as cde

-from .validators import check_from_file, check_from_list, check_from_dict
+from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset


class Vocab(cde.Vocab):
    """
    Vocab object that is used for lookup word
    """

    def __init__(self):
        pass

    @classmethod
    @check_from_dataset
    def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None):
        """
        Build a vocab from a dataset. This collects all the unique words in the dataset and returns a vocab
        containing the top_k most frequent words (if top_k is specified).

        Args:
            dataset(Dataset): dataset to build the vocab from.
            columns(str or list, optional): column names to get words from. It can be a list of column names
                (default=None, where all columns are used; an error is raised if any selected column is not of
                string type).
            freq_range(tuple, optional): a tuple of integers (min_frequency, max_frequency). Words within the
                frequency range are kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency and
                max_frequency can be None, which corresponds to 0 and total_words respectively
                (default=None, all words are included).
            top_k(int, optional): top_k > 0. Number of words to build into the vocab. The top_k most frequent
                words are taken; top_k is applied after freq_range. If there are fewer than top_k words, all
                words are taken (default=None, all words are included).

        Returns:
            text.Vocab, vocab object built from the dataset.
        """
        vocab = Vocab()
        root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k)
        for d in root.create_dict_iterator():
            if d is not None:
                raise ValueError("from_dataset should receive data other than None")
        return vocab

    @classmethod
    @check_from_list
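Put together, the intended user-facing flow (the same one exercised by the tests added below) is to build the vocab from a dataset and then map a Lookup over the text column. A minimal sketch with a placeholder corpus path:

import mindspore.dataset as ds
import mindspore.dataset.text as text

# build a vocab from the "text" column, keeping every word
data = ds.TextFileDataset("/path/to/corpus.txt", shuffle=False)
vocab = text.Vocab.from_dataset(data, columns="text", freq_range=None, top_k=None)

# reuse the vocab to convert the words of the same column into ids
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
for row in data.create_dict_iterator():
    print(row["text"])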
@@ -186,6 +186,62 @@ def check_jieba_add_dict(method):
    return new_method


def check_from_dataset(method):
    """A wrapper that wraps a parameter checker around the original function (Vocab.from_dataset)."""

    # def from_dataset(cls, dataset, columns, freq_range=None, top_k=None):
    @wraps(method)
    def new_method(self, *args, **kwargs):
        dataset, columns, freq_range, top_k = (list(args) + 4 * [None])[:4]
        if "dataset" in kwargs:
            dataset = kwargs.get("dataset")
        if "columns" in kwargs:
            columns = kwargs.get("columns")
        if "freq_range" in kwargs:
            freq_range = kwargs.get("freq_range")
        if "top_k" in kwargs:
            top_k = kwargs.get("top_k")

        if columns is None:
            columns = []

        if not isinstance(columns, list):
            columns = [columns]

        for column in columns:
            if not isinstance(column, str):
                raise ValueError("columns need to be a list of strings")

        if freq_range is None:
            freq_range = (None, None)

        if not isinstance(freq_range, tuple) or len(freq_range) != 2:
            raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None")

        for num in freq_range:
            if num is not None and (not isinstance(num, int)):
                raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None")

        if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
            if freq_range[0] > freq_range[1] or freq_range[0] < 0:
                raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)")

        if top_k is not None and (not isinstance(top_k, int)):
            raise ValueError("top_k needs to be a positive integer")

        if isinstance(top_k, int) and top_k <= 0:
            raise ValueError("top_k needs to be a positive integer")

        kwargs["dataset"] = dataset
        kwargs["columns"] = columns
        kwargs["freq_range"] = freq_range
        kwargs["top_k"] = top_k

        return method(self, **kwargs)

    return new_method


def check_ngram(method):
    """A wrapper that wrap a parameter checker to the original function(crop operation)."""
@@ -0,0 +1,112 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing from_dataset in mindspore.dataset
"""
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.text as text
def test_demo_basic_from_dataset():
    """ this is a tutorial on how from_dataset should be used in a normal use case"""
    data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
    vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None)
    data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
    res = []
    for d in data.create_dict_iterator():
        res.append(d["text"].item())
    assert res == [4, 5, 3, 6, 7, 2]


def test_demo_basic_from_dataset_with_tokenizer():
    """ this is a tutorial on how from_dataset should be used in a normal use case with a tokenizer"""
    data = ds.TextFileDataset("../data/dataset/testTokenizerData/1.txt", shuffle=False)
    data = data.map(input_columns=["text"], operations=text.UnicodeCharTokenizer())
    vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None)
    data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
    res = []
    for d in data.create_dict_iterator():
        res.append(list(d["text"]))
    assert res == [[13, 3, 7, 14, 9, 17, 3, 2, 19, 9, 2, 11, 3, 4, 16, 4, 8, 6, 5], [21, 20, 10, 25, 23, 26],
                   [24, 22, 10, 12, 8, 6, 7, 4, 18, 15, 5], [2, 2]]
def test_from_dataset():
    """ test build vocab with a generator dataset """

    def gen_corpus():
        # key: word, value: number of occurrences; letters are used so their order is apparent
        corpus = {"Z": 4, "Y": 4, "X": 4, "W": 3, "U": 3, "V": 2, "T": 1}
        for k, v in corpus.items():
            yield (np.array([k] * v, dtype='S'),)

    def test_config(freq_range, top_k):
        corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"])
        vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k)
        corpus_dataset = corpus_dataset.map(input_columns="text", operations=text.Lookup(vocab))
        res = []
        for d in corpus_dataset.create_dict_iterator():
            res.append(list(d["text"]))
        return res

    # take words whose frequency is within [3, 4]; words with the same frequency are ordered alphabetically
    test1_res = test_config(freq_range=(3, 4), top_k=4)
    assert test1_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [5, 5, 5], [1, 1], [1]], str(test1_res)
    # test words with frequency range [2, inf]; only the last word will be filtered out
    test2_res = test_config((2, None), None)
    assert test2_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [6, 6, 6], [5, 5, 5], [7, 7], [1]], str(test2_res)
    # test filtering only by top_k
    test3_res = test_config(None, 4)
    assert test3_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [5, 5, 5], [1, 1], [1]], str(test3_res)
    # test filtering out the most frequent words
    test4_res = test_config((None, 3), 100)
    assert test4_res == [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [3, 3, 3], [2, 2, 2], [4, 4], [5]], str(test4_res)
    # test top_k == 1
    test5_res = test_config(None, 1)
    assert test5_res == [[1, 1, 1, 1], [1, 1, 1, 1], [2, 2, 2, 2], [1, 1, 1], [1, 1, 1], [1, 1], [1]], str(test5_res)
    # test min_frequency == max_frequency
    test6_res = test_config((4, 4), None)
    assert test6_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [1, 1, 1], [1, 1], [1]], str(test6_res)
def test_from_dataset_exceptions():
    """ test the various exceptions that are checked in the validator """

    def test_config(columns, freq_range, top_k, s):
        try:
            data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
            vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k)
            assert isinstance(vocab, text.Vocab)
        except ValueError as e:
            assert s in str(e), str(e)

    test_config("text", (), 1, "freq_range needs to be either None or a tuple of 2 integers")
    test_config("text", (2, 3), 1.2345, "top_k needs to be a positive integer")
    test_config(23, (2, 3), 1.2345, "columns need to be a list of strings")
    test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b")
    test_config("text", (2, 3), 0, "top_k needs to be a positive integer")
    test_config([123], (2, 3), 0, "columns need to be a list of strings")


if __name__ == '__main__':
    test_demo_basic_from_dataset()
    test_from_dataset()
    test_from_dataset_exceptions()
    test_demo_basic_from_dataset_with_tokenizer()
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""
-Testing NgramOP in DE
+Testing Ngram in mindspore.dataset
"""
import mindspore.dataset as ds
import mindspore.dataset.text as nlp