implemented from_dataset

fix compile error

more tests

address CI complaints

fix ci

address review comments

address review comments
tags/v0.5.0-beta
Zirui Wu 5 years ago
commit 880ce5ea26
17 changed files with 692 additions and 24 deletions
  1. mindspore/ccsrc/dataset/api/de_pipeline.cc (+33, -1)
  2. mindspore/ccsrc/dataset/api/de_pipeline.h (+4, -1)
  3. mindspore/ccsrc/dataset/api/python_bindings.cc (+2, -0)
  4. mindspore/ccsrc/dataset/core/client.h (+1, -0)
  5. mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt (+1, -0)
  6. mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc (+179, -0)
  7. mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h (+153, -0)
  8. mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc (+1, -3)
  9. mindspore/ccsrc/dataset/text/vocab.cc (+18, -8)
  10. mindspore/ccsrc/dataset/text/vocab.h (+8, -1)
  11. mindspore/dataset/engine/__init__.py (+2, -2)
  12. mindspore/dataset/engine/datasets.py (+92, -3)
  13. mindspore/dataset/engine/iterators.py (+2, -0)
  14. mindspore/dataset/text/utils.py (+27, -4)
  15. mindspore/dataset/text/validators.py (+56, -0)
  16. tests/ut/python/dataset/test_from_dataset.py (+112, -0)
  17. tests/ut/python/dataset/test_ngram_op.py (+1, -1)

mindspore/ccsrc/dataset/api/de_pipeline.cc (+33, -1)

@@ -71,7 +71,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{kCifar100, &DEPipeline::ParseCifar100Op},
{kCelebA, &DEPipeline::ParseCelebAOp},
{kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp}};
{kTextFile, &DEPipeline::ParseTextFileOp},
{kBuildVocab, &DEPipeline::ParseBuildVocabOp}};

DEPipeline::DEPipeline() : iterator_(nullptr) {
try {
@@ -1235,5 +1236,36 @@ Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
}
return Status::OK();
}
Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<BuildVocabOp::Builder> builder = std::make_shared<BuildVocabOp::Builder>();
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "freq_range") {
py::tuple tp = py::reinterpret_borrow<py::tuple>(value);
if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow<py::int_>(tp[0]));
if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow<py::int_>(tp[1]));
}
if (key == "top_k") {
builder->SetTopK(py::reinterpret_borrow<py::int_>(value));
}
if (key == "columns") {
(void)builder->SetColumnNames(ToStringVector(value));
}
if (key == "vocab") {
(void)builder->SetVocab(value.cast<std::shared_ptr<Vocab>>());
}
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
}
}
}
std::shared_ptr<BuildVocabOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore
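For orientation, the keyword arguments that ParseBuildVocabOp walks over come from BuildVocabDataset.get_args() on the Python side. A minimal sketch of such a dict follows; the keys mirror what the parser reads, while the values are purely illustrative and not taken from the commit.

import mindspore.dataset.text as text

vocab = text.Vocab()               # empty Vocab; BuildVocabOp fills it via append_word
# illustrative argument dict; None values would simply be skipped by the parser
args = {
    "vocab": vocab,                # shared Vocab object passed through the pipeline
    "columns": ["text"],           # string columns to collect words from
    "freq_range": (1, None),       # (min_freq, max_freq); a None entry is ignored
    "top_k": 5000,                 # keep only the 5000 most frequent words
    "num_parallel_workers": 4,     # forwarded to the op's builder
}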

mindspore/ccsrc/dataset/api/de_pipeline.h (+4, -1)

@@ -63,7 +63,8 @@ enum OpName {
kCifar100,
kCelebA,
kRandomData,
kTextFile
kTextFile,
kBuildVocab
};

// The C++ binder class that we expose to the python script.
@@ -163,6 +164,8 @@ class DEPipeline {

Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

private:
// Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_;


mindspore/ccsrc/dataset/api/python_bindings.cc (+2, -0)

@@ -514,6 +514,7 @@ void bindInfoObjects(py::module *m) {

void bindVocabObjects(py::module *m) {
(void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
.def(py::init<>())
.def_static("from_list",
[](const py::list &words) {
std::shared_ptr<Vocab> v;
@@ -624,6 +625,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("CIFAR10", OpName::kCifar10)
.value("CIFAR100", OpName::kCifar100)
.value("RANDOMDATA", OpName::kRandomData)
.value("BUILDVOCAB", OpName::kBuildVocab)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile);



mindspore/ccsrc/dataset/core/client.h (+1, -0)

@@ -27,6 +27,7 @@
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/barrier_op.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/build_vocab_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
#include "dataset/engine/datasetops/map_op.h"


mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt (+1, -0)

@@ -19,5 +19,6 @@ add_library(engine-datasetops OBJECT
zip_op.cc
concat_op.cc
filter_op.cc
build_vocab_op.cc
)


mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc (+179, -0)

@@ -0,0 +1,179 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "dataset/engine/datasetops/build_vocab_op.h"

#include <algorithm>
#include <limits>
#include <string>
#include <unordered_map>
#include <utility>
#include "dataset/core/config_manager.h"

namespace mindspore {
namespace dataset {

BuildVocabOp::BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names,
std::pair<int64_t, int64_t> freq_r, int64_t top_k, int32_t num_workers, int32_t op_conn_size)
: ParallelOp(num_workers, op_conn_size),
interval_(op_conn_size * num_workers),
vocab_(vocab),
col_names_(col_names),
freq_range_(freq_r),
top_k_(top_k) {
// init two queues for thread sync
distributor_queue_ = std::make_unique<Queue<TensorRow>>(num_workers * op_conn_size);
collector_queue_ =
std::make_unique<Queue<std::unique_ptr<std::unordered_map<std::string, int64_t>>>>(num_workers * op_conn_size);
}

Status BuildVocabOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
TensorRow new_row;
RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row));
std::unique_ptr<std::unordered_map<std::string, int64_t>> wrkr_map =
std::make_unique<std::unordered_map<std::string, int64_t>>();
int32_t row_cnt = 0;
while (!new_row.empty()) {
for (int32_t col : col_ids_) {
CHECK_FAIL_RETURN_UNEXPECTED(!new_row[col]->type().IsNumeric(), "from_dataset only works on string columns");
for (auto itr = new_row[col]->begin<std::string_view>(); itr != new_row[col]->end<std::string_view>(); itr++) {
(*wrkr_map)[std::string(*itr)] += 1;
}
}
row_cnt++; // row is processed by this point
if ((row_cnt % interval_ == 0) && ((row_cnt / interval_) % num_workers_ == worker_id) && (!wrkr_map->empty())) {
RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map)));
wrkr_map = std::make_unique<std::unordered_map<std::string, int64_t>>();
}
RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row));
}
// clean up
if (!wrkr_map->empty()) {
RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map)));
}
// empty map as quit signal
RETURN_IF_NOT_OK(collector_queue_->Add(std::make_unique<std::unordered_map<std::string, int64_t>>()));
return Status::OK();
}

Status BuildVocabOp::operator()() {
// launch the collector thread
RETURN_UNEXPECTED_IF_NULL(tree_);
RETURN_IF_NOT_OK(distributor_queue_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(collector_queue_->Register(tree_->AllTasks()));
// launch worker threads and collector thread
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&BuildVocabOp::WorkerEntry, this, std::placeholders::_1)));
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("collector", std::bind(&BuildVocabOp::CollectorThread, this)));
TaskManager::FindMe()->Post();
child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
TensorRow new_row;
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
RETURN_IF_NOT_OK(AssignColMapFromChild());
if (!col_names_.empty()) {
col_ids_.reserve(col_names_.size());
for (std::string col : col_names_) {
auto itr = column_name_id_map_.find(col);
CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(), col + " column doesn't exist");
col_ids_.push_back(itr->second);
}
} else {
col_ids_.reserve(column_name_id_map_.size());
for (const auto &p : column_name_id_map_) {
col_ids_.push_back(p.second);
}
}
bool eoe_warning = false; // give out warning if receive more than 1 eoe
while (child_iterator_->eof_handled() == false) {
while (new_row.empty() == false) {
RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(new_row));
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
}
CHECK_FAIL_RETURN_UNEXPECTED(!eoe_warning, "no op should be after from_dataset (repeat detected)");
eoe_warning = true;
}

// tell all workers to quit
for (int32_t wrkr_id = 0; wrkr_id < num_workers_; wrkr_id++) {
RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(TensorRow()));
}
return Status::OK();
}

Status BuildVocabOp::CollectorThread() {
TaskManager::FindMe()->Post();
int32_t num_quited_worker = 0;
std::unique_ptr<std::unordered_map<std::string, int64_t>> wrkr_map;
while (num_quited_worker != num_workers_) {
RETURN_IF_NOT_OK(collector_queue_->PopFront(&wrkr_map));
RETURN_UNEXPECTED_IF_NULL(wrkr_map);
if (!wrkr_map->empty()) {
for (const auto &wd : *wrkr_map) word_cnt_[wd.first] += wd.second;
} else {
++num_quited_worker;
}
} // all frequencies are obtained
CHECK_FAIL_RETURN_UNEXPECTED(!word_cnt_.empty(), "word_cnt is empty");
std::vector<std::string> words;
// reserve enough capacity for all counted words
words.reserve(word_cnt_.size());

for (auto it = word_cnt_.begin(); it != word_cnt_.end();) {
if (it->second >= freq_range_.first && it->second <= freq_range_.second) {
words.push_back(it->first);
it++;
} else {
it = word_cnt_.erase(it);
}
}
int64_t num_words = std::min(static_cast<int64_t>(words.size()), top_k_);
// this would take the top-k most frequent words
std::partial_sort(words.begin(), words.begin() + num_words, words.end(),
[this](const std::string &w1, const std::string &w2) {
int64_t f1 = word_cnt_[w1], f2 = word_cnt_[w2];
return f1 == f2 ? w1 < w2 : f1 > f2;
});
for (int64_t i = 0; i < num_words; i++) {
vocab_->append_word(words[i]);
}
RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
return Status::OK();
}

Status BuildVocabOp::Builder::Build(std::shared_ptr<BuildVocabOp> *op) {
CHECK_FAIL_RETURN_UNEXPECTED(builder_num_workers_ > 0, "builder num_workers needs to be greater than 0");
CHECK_FAIL_RETURN_UNEXPECTED(builder_top_k_ > 0, "top_k needs to be a positive number");
CHECK_FAIL_RETURN_UNEXPECTED(builder_max_freq_ >= builder_min_freq_ && builder_min_freq_ >= 0,
"frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)");
(*op) = std::make_shared<BuildVocabOp>(builder_vocab_, builder_col_names_,
std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_,
builder_num_workers_, builder_connector_size_);
return Status::OK();
}

BuildVocabOp::Builder::Builder()
: builder_top_k_(std::numeric_limits<int64_t>::max()),
builder_min_freq_(0),
builder_max_freq_(std::numeric_limits<int64_t>::max()) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_connector_size_ = cfg->op_connector_size();
}
} // namespace dataset
} // namespace mindspore
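To summarize the selection step above: CollectorThread merges the per-worker frequency maps, drops words outside freq_range, then partially sorts by descending frequency with alphabetical tie-breaking and keeps the first top_k. The following is a rough Python sketch of that logic, not the op itself; the function name and sample counts are made up for illustration.

def select_vocab_words(word_cnt, min_freq, max_freq, top_k):
    # keep only words whose frequency lies in [min_freq, max_freq]
    words = [w for w, c in word_cnt.items() if min_freq <= c <= max_freq]
    # higher frequency first; ties broken alphabetically, matching the C++ comparator
    words.sort(key=lambda w: (-word_cnt[w], w))
    return words[:top_k]

counts = {"Z": 4, "Y": 4, "X": 4, "W": 3, "U": 3, "V": 2, "T": 1}
print(select_vocab_words(counts, 3, 4, 4))  # ['X', 'Y', 'Z', 'U']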

mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h (+153, -0)

@@ -0,0 +1,153 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_
#define DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_

#include <vector>
#include <memory>
#include <unordered_map>
#include <string>
#include <utility>

#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/text/vocab.h"
#include "dataset/util/queue.h"
#include "dataset/util/status.h"

namespace mindspore {
namespace dataset {
class BuildVocabOp : public ParallelOp {
public:
class Builder {
public:
Builder();

// Destructor.
~Builder() = default;

// Setter method
// @param int32_t size
// @return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t size) {
builder_connector_size_ = size;
return *this;
}

// Setter method
// @param int32_t num_workers
// @return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
builder_num_workers_ = num_workers;
return *this;
}

// Setter method
// @param int64_t top_k
// @return Builder setter method returns reference to the builder.
Builder &SetTopK(int64_t top_k) {
builder_top_k_ = top_k;
return *this;
}

// Setter method
// @param int64_t min_freq
// @return Builder setter method returns reference to the builder.
Builder &SetMinFreq(int64_t min_freq) {
builder_min_freq_ = min_freq;
return *this;
}

// Setter method
// @param int64_t max_freq
// @return Builder setter method returns reference to the builder.
Builder &SetMaxFreq(int64_t max_freq) {
builder_max_freq_ = max_freq;
return *this;
}

// set columns names
// @param const std::vector<std::string> & col_names - name of columns to get words
// @return Builder & reference to builder class object
Builder &SetColumnNames(const std::vector<std::string> &col_names) {
builder_col_names_ = col_names;
return *this;
}

// set vocab object
Builder &SetVocab(std::shared_ptr<Vocab> vocab) {
builder_vocab_ = vocab;
return *this;
}

// The builder "build" method creates the final object.
// @param std::shared_ptr<BuildVocabOp> *op - DatasetOp
// @return - The error code return
Status Build(std::shared_ptr<BuildVocabOp> *op);

private:
int32_t builder_num_workers_;
int32_t builder_connector_size_;
int64_t builder_min_freq_;
int64_t builder_max_freq_;
std::vector<std::string> builder_col_names_;
std::shared_ptr<Vocab> builder_vocab_;
int64_t builder_top_k_;
};

BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names, std::pair<int64_t, int64_t> freq_range,
int64_t top_k, int32_t num_workers, int32_t op_connector_size);

Status WorkerEntry(int32_t worker_id) override;

// collect the work product from each worker
Status CollectorThread();

Status EofReceived(int32_t) override { return Status::OK(); }

Status EoeReceived(int32_t) override { return Status::OK(); }

Status operator()() override;

// Getter
// @return the number of threads producing to the output connector (the collector thread only)
int32_t num_producers() const override { return 1; }

// Getter
// @return the number of threads consuming from the previous Connector
int32_t num_consumers() const override { return 1; }

Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); }

private:
const int32_t interval_;
std::shared_ptr<Vocab> vocab_;
std::vector<std::string> col_names_;
std::vector<int32_t> col_ids_;
// pair = {min_f, max_f}
// the builder guarantees 0 <= min_f <= max_f <= int64_max (both bounds inclusive)
std::pair<int64_t, int64_t> freq_range_;

int64_t top_k_;  // when top_k_ == int64_max, every word is taken
std::unique_ptr<ChildIterator> child_iterator_; // child iterator for fetching TensorRows 1 by 1
std::unique_ptr<Queue<TensorRow>> distributor_queue_; // master thread assigns each worker TensorRow via this
std::unique_ptr<Queue<std::unique_ptr<std::unordered_map<std::string, int64_t>>>> collector_queue_;
std::unordered_map<std::string, int64_t> word_cnt_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_

mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc (+1, -3)

@@ -364,9 +364,7 @@ Status ImageFolderOp::startAsyncWalk() {
}

Status ImageFolderOp::LaunchThreadsAndInitOp() {
if (tree_ == nullptr) {
RETURN_STATUS_UNEXPECTED("tree_ not set");
}
RETURN_UNEXPECTED_IF_NULL(tree_);
// Registers QueueList and individual Queues for interrupt services
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(folder_name_queue_->Register(tree_->AllTasks()));


mindspore/ccsrc/dataset/text/vocab.cc (+18, -8)

@@ -21,19 +21,23 @@

namespace mindspore {
namespace dataset {
Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) {
word2id_ = std::move(word2id);
id2word_.resize(word2id_.size());
for (auto p : word2id_) {
id2word_[p.second - kSpecialTokens::num_tokens] = p.first;
}
}
Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) { word2id_ = std::move(word2id); }

WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const {
auto itr = word2id_.find(word);
return itr == word2id_.end() ? default_id : itr->second;
}
WordType Vocab::Lookup(WordIdType id) const {

WordType Vocab::Lookup(WordIdType id) {
// this is most likely done only once, since reverse lookup is only needed after training is done;
// hence the worst case of inserting words while repeatedly looking up is unlikely to happen
if (id2word_.size() != word2id_.size() && (id - kSpecialTokens::num_tokens >= id2word_.size())) {
id2word_.clear();
id2word_.resize(word2id_.size());
for (auto p : word2id_) {
id2word_[p.second - kSpecialTokens::num_tokens] = p.first;
}
}
if (id < kSpecialTokens::num_tokens) {
return reserved_token_str_[id];
} else if (id - kSpecialTokens::num_tokens >= id2word_.size()) {
@@ -97,5 +101,11 @@ Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *voc
return Status::OK();
}
const std::vector<WordType> Vocab::reserved_token_str_ = {"<pad>", "<unk>"};

void Vocab::append_word(const std::string &word) {
if (word2id_.find(word) == word2id_.end()) {
word2id_[word] = word2id_.size() + kSpecialTokens::num_tokens;
}
}
} // namespace dataset
} // namespace mindspore
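append_word gives every new word the next free id after the reserved tokens; reserved_token_str_ holds "<pad>" and "<unk>", so ordinary words start at id 2 (this assumes kSpecialTokens::num_tokens == 2, which the two reserved strings suggest). A small Python sketch of the id assignment, with made-up sample words:

NUM_SPECIAL_TOKENS = 2  # assumed: "<pad>" -> 0, "<unk>" -> 1, per reserved_token_str_

def append_word(word2id, word):
    # ids grow with insertion order; duplicates are silently skipped, as in Vocab::append_word
    if word not in word2id:
        word2id[word] = len(word2id) + NUM_SPECIAL_TOKENS

w2i = {}
for w in ["hello", "world", "hello"]:
    append_word(w2i, w)
print(w2i)  # {'hello': 2, 'world': 3}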

mindspore/ccsrc/dataset/text/vocab.h (+8, -1)

@@ -65,12 +65,19 @@ class Vocab {
// reverse lookup, lookup the word based on its id
// @param WordIdType id - word id to lookup to
// @return WordType the word
WordType Lookup(WordIdType id) const;
WordType Lookup(WordIdType id);

// constructor, shouldn't be called directly, can't be private due to std::make_unique()
// @param std::unordered_map<WordType, WordIdType> map - sanitized word2id map
explicit Vocab(std::unordered_map<WordType, WordIdType> map);

Vocab() = default;

// add one word to the vocab; its id is assigned automatically
// @param std::string & word - word to be added; skipped if the word already exists
void append_word(const std::string &word);

// destructor
~Vocab() = default;

// enum type that holds all special tokens, add more if needed


mindspore/dataset/engine/__init__.py (+2, -2)

@@ -28,10 +28,10 @@ from .serializer_deserializer import serialize, deserialize, show, compare
from .samplers import *
from ..core.configuration import config, ConfigurationManager


__all__ = ["config", "ConfigurationManager", "zip",
"ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
"VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler",
"VOCDataset", "CocoDataset", "TextFileDataset", "BuildVocabDataset", "Schema", "Schema",
"DistributedSampler", "PKSampler",
"RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"]

mindspore/dataset/engine/datasets.py (+92, -3)

@@ -42,8 +42,8 @@ from .iterators import DictIterator, TupleIterator
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, check_numpyslicesdataset, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset,\
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat,\
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_split
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

@@ -824,6 +824,29 @@ class Dataset:

return ProjectDataset(self, columns)

def build_vocab(self, vocab, columns, freq_range, top_k):
"""
Build a vocab from a dataset. This collects all the unique words in the dataset and returns a vocab
containing the top_k most frequent words (if top_k is specified).
This function is not meant to be called directly by the user. To build a vocab, please use
text.Vocab.from_dataset() instead.

Args:
vocab(Vocab): vocab object.
columns(str or list, optional): column names to get words from. It can be a list of column names
(default is None, where all columns are used; an error is returned if any column is not of string type).
freq_range(tuple, optional): a tuple of integers (min_frequency, max_frequency). Words within the frequency
range are kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
can be None, which corresponds to 0/total_words respectively (default is None, all words are included).
top_k(int, optional): top_k > 0. Number of words to be built into the vocab. The top_k most frequent
words are taken. top_k is applied after freq_range. If there are fewer than top_k words left, all of
them are taken (default is None, all words are included).

Returns:
BuildVocabDataset
"""
return BuildVocabDataset(self, vocab, columns, freq_range, top_k)

def apply(self, apply_func):
"""
Apply a function in this dataset.
@@ -1483,6 +1506,7 @@ class BatchDataset(DatasetOp):
for input_dataset in dataset.input:
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)


class BatchInfo(CBatchInfo):
"""
The information object associates with the current batch of tensors.
@@ -1506,6 +1530,7 @@ class BatchInfo(CBatchInfo):
"""
return


class BlockReleasePair:
"""
The blocking condition class used by SyncWaitDataset.
@@ -1514,6 +1539,7 @@ class BlockReleasePair:
init_release_rows (int): Number of lines to allow through the pipeline.
callback (function): The callback function that will be called when release is called.
"""

def __init__(self, init_release_rows, callback=None):
if isinstance(init_release_rows, int) and init_release_rows <= 0:
raise ValueError("release_rows need to be greater than 0.")
@@ -1696,6 +1722,7 @@ class _PythonCallable:
"""
Internal python function wrapper for multiprocessing pyfunc.
"""

def __init__(self, py_callable, idx, pool=None):
# Original python callable from user.
self.py_callable = py_callable
@@ -2593,7 +2620,7 @@ class MindDataset(SourceDataset):

if sampler is not None:
if isinstance(sampler, samplers.SubsetRandomSampler) is False and \
isinstance(sampler, samplers.PKSampler) is False:
isinstance(sampler, samplers.PKSampler) is False:
raise ValueError("the sampler is not supported yet.")

# sampler exclusive
@@ -2859,6 +2886,7 @@ class _GeneratorWorker(multiprocessing.Process):
"""
Worker process for multiprocess Generator.
"""

def __init__(self, dataset, eoe):
self.idx_queue = multiprocessing.Queue(16)
self.res_queue = multiprocessing.Queue(16)
@@ -3686,6 +3714,7 @@ class RandomDataset(SourceDataset):
def is_sharded(self):
return False


class Schema:
"""
Class to represent a schema of dataset.
@@ -4383,6 +4412,7 @@ class _NumpySlicesDataset:
"""
Mainly for dealing with several kinds of format of python data, and return one row each time.
"""

def __init__(self, data, column_list=None):
self.column_list = None
# Convert dict data into tuple
@@ -4525,6 +4555,7 @@ class NumpySlicesDataset(GeneratorDataset):
>>> df = pd.read_csv("file.csv")
>>> dataset4 = ds.NumpySlicesDataset(dict(df), shuffle=False)
"""

@check_numpyslicesdataset
def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None,
sampler=None, num_shards=None, shard_id=None):
@@ -4532,3 +4563,61 @@ class NumpySlicesDataset(GeneratorDataset):
super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
num_shards=num_shards, shard_id=shard_id)


class BuildVocabDataset(DatasetOp):
"""
Build a vocab from a dataset. This collects all the unique words in the dataset and returns a vocab
containing the top_k most frequent words (if top_k is specified).
This class is not meant to be instantiated directly by the user. To build a vocab, please use
text.Vocab.from_dataset() instead.

Args:
vocab(Vocab): vocab object.
columns(str or list, optional): column names to get words from. It can be a list of column names
(default is None, where all columns are used; an error is returned if any column is not of string type).
freq_range(tuple, optional): a tuple of integers (min_frequency, max_frequency). Words within the frequency
range are kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
can be None, which corresponds to 0/total_words respectively (default is None, all words are included).
top_k(int, optional): top_k > 0. Number of words to be built into the vocab. The top_k most frequent
words are taken. top_k is applied after freq_range. If there are fewer than top_k words left, all of
them are taken (default is None, all words are included).

Returns:
BuildVocabDataset
"""

def __init__(self, input_dataset, vocab, columns, freq_range, top_k, prefetch_size=None):
super().__init__()
self.columns = columns
self.input.append(input_dataset)
self.prefetch_size = prefetch_size
self.vocab = vocab
self.freq_range = freq_range
self.top_k = top_k
input_dataset.output.append(self)

def get_args(self):
args = super().get_args()
args["columns"] = self.columns
args["vocab"] = self.vocab
args["freq_range"] = self.freq_range
args["prefetch_size"] = self.prefetch_size
args["top_k"] = self.top_k
return args

def __deepcopy__(self, memodict):
if id(self) in memodict:
return memodict[id(self)]
cls = self.__class__
new_op = cls.__new__(cls)
memodict[id(self)] = new_op
new_op.input = copy.deepcopy(self.input, memodict)
new_op.columns = copy.deepcopy(self.columns, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
new_op.output = copy.deepcopy(self.output, memodict)
new_op.freq_range = copy.deepcopy(self.freq_range, memodict)
new_op.top_k = copy.deepcopy(self.top_k, memodict)
new_op.vocab = self.vocab
return new_op
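The __deepcopy__ above copies the pipeline state but deliberately shares the vocab object (new_op.vocab = self.vocab), so words appended while the copied pipeline runs remain visible through the caller's original Vocab handle. A toy stand-in to illustrate the idea; the class and attribute names below are invented for the sketch and are not the MindSpore implementation.

import copy

class ToyBuildVocabNode:
    def __init__(self, vocab):
        self.vocab = vocab
        self.columns = ["text"]

    def __deepcopy__(self, memodict):
        new_op = ToyBuildVocabNode.__new__(ToyBuildVocabNode)
        memodict[id(self)] = new_op
        new_op.columns = copy.deepcopy(self.columns, memodict)
        new_op.vocab = self.vocab  # shared on purpose: the built vocab must reach the caller
        return new_op

shared_vocab = {}
original = ToyBuildVocabNode(shared_vocab)
clone = copy.deepcopy(original)
assert clone.vocab is original.vocab  # same object, so appended words are visible to both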

mindspore/dataset/engine/iterators.py (+2, -0)

@@ -177,6 +177,8 @@ class Iterator:
op_type = OpName.RANDOMDATA
elif isinstance(dataset, de.TextFileDataset):
op_type = OpName.TEXTFILE
elif isinstance(dataset, de.BuildVocabDataset):
op_type = OpName.BUILDVOCAB
else:
raise ValueError("Unsupported DatasetOp")



mindspore/dataset/text/utils.py (+27, -4)

@@ -16,20 +16,43 @@ Some basic function for nlp
"""
from enum import IntEnum

import copy
import numpy as np
import mindspore._c_dataengine as cde

from .validators import check_from_file, check_from_list, check_from_dict
from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset


class Vocab(cde.Vocab):
"""
Vocab object that is used to look up words and their ids.
"""

def __init__(self):
pass
@classmethod
@check_from_dataset
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None):
"""
Build a vocab from a dataset. This collects all the unique words in the dataset and returns a vocab
containing the top_k most frequent words (if top_k is specified).

Args:
dataset(Dataset): dataset to build the vocab from.
columns(str or list, optional): column names to get words from. It can be a list of column names
(default is None, where all columns are used; an error is returned if any column is not of string type).
freq_range(tuple, optional): a tuple of integers (min_frequency, max_frequency). Words within the frequency
range are kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
can be None, which corresponds to 0/total_words respectively (default is None, all words are included).
top_k(int, optional): top_k > 0. Number of words to be built into the vocab. The top_k most frequent
words are taken. top_k is applied after freq_range. If there are fewer than top_k words left, all of
them are taken (default is None, all words are included).

Returns:
text.Vocab: vocab object built from dataset.
"""
vocab = Vocab()
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k)
for d in root.create_dict_iterator():
if d is not None:
raise ValueError("from_dataset should receive data other than None")
return vocab

@classmethod
@check_from_list
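Put together, from_dataset is used the way the tests below do: build the vocab from a string column first, then map a Lookup op that uses it. The file path and column name here are taken from the test file, so they only resolve inside the test environment.

import mindspore.dataset as ds
import mindspore.dataset.text as text

data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None)  # build vocab first
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))      # then map word -> id
for row in data.create_dict_iterator():
    print(row["text"])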


mindspore/dataset/text/validators.py (+56, -0)

@@ -186,6 +186,62 @@ def check_jieba_add_dict(method):
return new_method


def check_from_dataset(method):
"""A wrapper that wrap a parameter checker to the original function(crop operation)."""

# def from_dataset(cls, dataset, columns, freq_range=None, top_k=None):
@wraps(method)
def new_method(self, *args, **kwargs):
dataset, columns, freq_range, top_k = (list(args) + 4 * [None])[:4]
if "dataset" in kwargs:
dataset = kwargs.get("dataset")
if "columns" in kwargs:
columns = kwargs.get("columns")
if "freq_range" in kwargs:
freq_range = kwargs.get("freq_range")
if "top_k" in kwargs:
top_k = kwargs.get("top_k")

if columns is None:
columns = []

if not isinstance(columns, list):
columns = [columns]

for column in columns:
if not isinstance(column, str):
raise ValueError("columns need to be a list of strings")

if freq_range is None:
freq_range = (None, None)

if not isinstance(freq_range, tuple) or len(freq_range) != 2:
raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None")

for num in freq_range:
if num is not None and (not isinstance(num, int)):
raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None")

if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
if freq_range[0] > freq_range[1] or freq_range[0] < 0:
raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)")

if top_k is not None and (not isinstance(top_k, int)):
raise ValueError("top_k needs to be a positive integer")

if isinstance(top_k, int) and top_k <= 0:
raise ValueError("top_k needs to be a positive integer")

kwargs["dataset"] = dataset
kwargs["columns"] = columns
kwargs["freq_range"] = freq_range
kwargs["top_k"] = top_k

return method(self, **kwargs)

return new_method


def check_ngram(method):
"""A wrapper that wrap a parameter checker to the original function(crop operation)."""



tests/ut/python/dataset/test_from_dataset.py (+112, -0)

@@ -0,0 +1,112 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing from_dataset in mindspore.dataset
"""
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text


def test_demo_basic_from_dataset():
""" this is a tutorial on how from_dataset should be used in a normal use case"""
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None)
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
res = []
for d in data.create_dict_iterator():
res.append(d["text"].item())
assert res == [4, 5, 3, 6, 7, 2]


def test_demo_basic_from_dataset_with_tokenizer():
""" this is a tutorial on how from_dataset should be used in a normal use case with tokenizer"""
data = ds.TextFileDataset("../data/dataset/testTokenizerData/1.txt", shuffle=False)
data = data.map(input_columns=["text"], operations=text.UnicodeCharTokenizer())
vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None)
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
res = []
for d in data.create_dict_iterator():
res.append(list(d["text"]))
assert res == [[13, 3, 7, 14, 9, 17, 3, 2, 19, 9, 2, 11, 3, 4, 16, 4, 8, 6, 5], [21, 20, 10, 25, 23, 26],
[24, 22, 10, 12, 8, 6, 7, 4, 18, 15, 5], [2, 2]]


def test_from_dataset():
""" test build vocab with generator dataset """

def gen_corpus():
# key: word, value: number of occurrences, reason for using letters is so their order is apparent
corpus = {"Z": 4, "Y": 4, "X": 4, "W": 3, "U": 3, "V": 2, "T": 1}
for k, v in corpus.items():
yield (np.array([k] * v, dtype='S'),)

def test_config(freq_range, top_k):
corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"])
vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k)
corpus_dataset = corpus_dataset.map(input_columns="text", operations=text.Lookup(vocab))
res = []
for d in corpus_dataset.create_dict_iterator():
res.append(list(d["text"]))
return res

# take words whose frequency is within [3, 4]; order words with the same frequency alphabetically
test1_res = test_config(freq_range=(3, 4), top_k=4)
assert test1_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [5, 5, 5], [1, 1], [1]], str(test1_res)

# test words with frequency range [2,inf], only the last word will be filtered out
test2_res = test_config((2, None), None)
assert test2_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [6, 6, 6], [5, 5, 5], [7, 7], [1]], str(test2_res)

# test filter only by top_k
test3_res = test_config(None, 4)
assert test3_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [5, 5, 5], [1, 1], [1]], str(test3_res)

# test filtering out the most frequent
test4_res = test_config((None, 3), 100)
assert test4_res == [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [3, 3, 3], [2, 2, 2], [4, 4], [5]], str(test4_res)

# test top_k == 1
test5_res = test_config(None, 1)
assert test5_res == [[1, 1, 1, 1], [1, 1, 1, 1], [2, 2, 2, 2], [1, 1, 1], [1, 1, 1], [1, 1], [1]], str(test5_res)

# test min_frequency == max_frequency
test6_res = test_config((4, 4), None)
assert test6_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [1, 1, 1], [1, 1], [1]], str(test6_res)


def test_from_dataset_exceptions():
""" test various exceptions during that are checked in validator """

def test_config(columns, freq_range, top_k, s):
try:
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k)
assert isinstance(vocab, text.Vocab)
except ValueError as e:
assert s in str(e), str(e)

test_config("text", (), 1, "freq_range needs to be either None or a tuple of 2 integers")
test_config("text", (2, 3), 1.2345, "top_k needs to be a positive integer")
test_config(23, (2, 3), 1.2345, "columns need to be a list of strings")
test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b")
test_config("text", (2, 3), 0, "top_k needs to be a positive integer")
test_config([123], (2, 3), 0, "columns need to be a list of strings")

if __name__ == '__main__':
test_demo_basic_from_dataset()
test_from_dataset()
test_from_dataset_exceptions()
test_demo_basic_from_dataset_with_tokenizer()

tests/ut/python/dataset/test_ngram_op.py (+1, -1)

@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""
Testing NgramOP in DE
Testing Ngram in mindspore.dataset
"""
import mindspore.dataset as ds
import mindspore.dataset.text as nlp

