@@ -61,7 +61,7 @@ set(submodules
     $<TARGET_OBJECTS:kernels>
     $<TARGET_OBJECTS:kernels-image>
    $<TARGET_OBJECTS:kernels-data>
-    $<TARGET_OBJECTS:kernels-nlp>
+    $<TARGET_OBJECTS:kernels-text>
     $<TARGET_OBJECTS:APItoPython>
     $<TARGET_OBJECTS:engine-datasetops-source>
     $<TARGET_OBJECTS:engine-datasetops-source-sampler>
@@ -39,6 +39,7 @@
 #include "dataset/kernels/image/uniform_aug_op.h"
 #include "dataset/kernels/data/type_cast_op.h"
 #include "dataset/kernels/text/jieba_tokenizer_op.h"
+#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
 #include "dataset/engine/datasetops/source/cifar_op.h"
 #include "dataset/engine/datasetops/source/image_folder_op.h"
 #include "dataset/engine/datasetops/source/io_block.h"
@@ -407,12 +408,16 @@ void bindTensorOps4(py::module *m) {
                 py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
 }
 
-void bindTensorOps6(py::module *m) {
+void bindTensorOps5(py::module *m) {
   (void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(*m, "JiebaTokenizerOp", "")
     .def(py::init<const std::string, std::string, JiebaMode>(), py::arg("hmm_path"), py::arg("mp_path"),
          py::arg("mode") = JiebaMode::kMix)
     .def("add_word",
          [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); });
+
+  (void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
+    *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
+    .def(py::init<>());
 }
 
 void bindSamplerOps(py::module *m) {
@@ -534,7 +539,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
   bindTensorOps2(&m);
   bindTensorOps3(&m);
   bindTensorOps4(&m);
-  bindTensorOps6(&m);
+  bindTensorOps5(&m);
   bindSamplerOps(&m);
   bindDatasetOps(&m);
   bindInfoObjects(&m);
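
A minimal smoke-check sketch of what the renumbered binding exposes once the extension is built. The module path is assumed from `PYBIND11_MODULE(_c_dataengine, m)` above; this snippet is illustration, not part of the patch.

```python
# Hypothetical smoke check for the new raw binding; module path assumed
# from PYBIND11_MODULE(_c_dataengine, m) above.
import mindspore._c_dataengine as cde

op = cde.UnicodeCharTokenizerOp()  # no-arg constructor bound via py::init<>()
print(op)                          # a TensorOp subclass visible from Python
```
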
@@ -1,5 +1,6 @@
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
-add_library(kernels-nlp OBJECT
+add_library(kernels-text OBJECT
     jieba_tokenizer_op.cc
+    unicode_char_tokenizer_op.cc
     )
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef DATASET_ENGINE_NLP_JIEBA_OP_H_
-#define DATASET_ENGINE_NLP_JIEBA_OP_H_
+#ifndef DATASET_ENGINE_TEXT_JIEBA_OP_H_
+#define DATASET_ENGINE_TEXT_JIEBA_OP_H_
 
 #include <string>
 #include <memory>
@@ -61,4 +61,4 @@ class JiebaTokenizerOp : public TensorOp {
 };
 }  // namespace dataset
 }  // namespace mindspore
-#endif  // DATASET_ENGINE_NLP_JIEBA_OP_H_
+#endif  // DATASET_ENGINE_TEXT_JIEBA_OP_H_
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
+#include <memory>
+#include <string>
+#include <string_view>
+#include <vector>
+
+#include "cppjieba/Unicode.hpp"
+
+using cppjieba::DecodeRunesInString;
+using cppjieba::RuneStrArray;
+
+namespace mindspore {
+namespace dataset {
+
+Status UnicodeCharTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
+    RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor");
+  }
+  std::string_view str;
+  RETURN_IF_NOT_OK(input->GetItemAt(&str, {}));
+  RuneStrArray runes;
+  if (!DecodeRunesInString(str.data(), str.size(), runes)) {
+    RETURN_STATUS_UNEXPECTED("Decode utf8 string failed.");
+  }
+  std::vector<std::string> splits(runes.size());
+  for (size_t i = 0; i < runes.size(); i++) {
+    splits[i] = str.substr(runes[i].offset, runes[i].len);
+  }
+  if (splits.empty()) {
+    splits.emplace_back("");
+  }
+  *output = std::make_shared<Tensor>(splits, TensorShape({(dsize_t)splits.size()}));
+  return Status::OK();
+}
+
+}  // namespace dataset
+}  // namespace mindspore
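
The `Compute` above decodes the UTF-8 input into runes and emits one string per code point, falling back to a single empty string for empty input. A rough Python analogue of that logic, for illustration only (Python 3 strings already iterate by code point):

```python
def unicode_char_tokenize(s: str) -> list:
    # One element per Unicode code point, mirroring the rune loop above.
    splits = list(s)
    # Mirror the splits.empty() branch: empty input yields [""].
    return splits if splits else [""]

assert unicode_char_tokenize("中国 你好!") == ["中", "国", " ", "你", "好", "!"]
assert unicode_char_tokenize("") == [""]
```
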
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
+#define DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
+#include <memory>
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/tensor_op.h"
+#include "dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+class UnicodeCharTokenizerOp : public TensorOp {
+ public:
+  UnicodeCharTokenizerOp() {}
+
+  ~UnicodeCharTokenizerOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "UnicodeCharTokenizerOp"; }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
@@ -284,10 +284,10 @@ class Dataset:
         Examples:
             >>> import mindspore.dataset as ds
-            >>> import mindspore.dataset.transforms.nlp.utils as nlp
+            >>> import mindspore.dataset.transforms.text.utils as text
             >>> # declare a function which returns a Dataset object
             >>> def flat_map_func(x):
-            >>>     data_dir = nlp.as_text(x[0])
+            >>>     data_dir = text.as_text(x[0])
             >>>     d = ds.ImageFolderDatasetV2(data_dir)
             >>>     return d
             >>> # data is a Dataset object
@@ -18,3 +18,4 @@ image augmentation module which is developed with c++ opencv. Py_transforms
 provide more kinds of image augmentations which is developed with python PIL.
 """
 from .utils import as_text, JiebaMode
+from . import c_transforms
@@ -123,3 +123,9 @@ class JiebaTokenizer(cde.JiebaTokenizerOp):
         if not os.path.exists(model_path):
             raise ValueError(
                 " jieba mode file {} is not exist".format(model_path))
+
+
+class UnicodeCharTokenizer(cde.UnicodeCharTokenizerOp):
+    """
+    Tokenize a scalar tensor of UTF-8 string to Unicode characters.
+    """
@@ -33,9 +33,7 @@ def as_text(array, encoding='utf8'):
     if not isinstance(array, np.ndarray):
         raise ValueError('input should be a numpy array')
 
-    def decode(x):
-        return x.decode(encoding)
-
-    decode = np.vectorize(decode)
+    decode = np.vectorize(lambda x: x.decode(encoding))
+
     return decode(array)
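
Behavior is unchanged by the lambda rewrite; a quick self-contained sketch of what the vectorized decoder does:

```python
import numpy as np

decode = np.vectorize(lambda x: x.decode('utf8'))  # same construction as above
arr = np.array([b"Hello", b"\xe4\xb8\xad"])        # second element is UTF-8 for "中"
print(decode(arr))                                 # -> ['Hello' '中']
```
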
@@ -69,6 +69,7 @@ SET(DE_UT_SRCS
     filter_op_test.cc
     concat_op_test.cc
     jieba_tokenizer_op_test.cc
+    tokenizer_op_test.cc
     )
 
 add_executable(de_ut_tests ${DE_UT_SRCS})
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <memory>
+#include <string>
+#include <string_view>
+
+#include "common/common.h"
+#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
+#include "gtest/gtest.h"
+#include "utils/log_adapter.h"
+
+using namespace mindspore::dataset;
+
+class MindDataTestTokenizerOp : public UT::Common {
+ public:
+  void CheckEqual(const std::shared_ptr<Tensor> &o,
+                  const std::vector<dsize_t> &index,
+                  const std::string &expect) {
+    std::string_view str;
+    Status s = o->GetItemAt(&str, index);
+    EXPECT_TRUE(s.IsOk());
+    EXPECT_EQ(str, expect);
+  }
+};
+
+TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) {
+  MS_LOG(INFO) << "Doing TestUnicodeCharTokenizerOp.";
+  std::unique_ptr<UnicodeCharTokenizerOp> op(new UnicodeCharTokenizerOp());
+  std::shared_ptr<Tensor> input = std::make_shared<Tensor>("Hello World!");
+  std::shared_ptr<Tensor> output;
+  Status s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 12);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor1: " << output->ToString();
+  CheckEqual(output, {0}, "H");
+  CheckEqual(output, {1}, "e");
+  CheckEqual(output, {2}, "l");
+  CheckEqual(output, {3}, "l");
+  CheckEqual(output, {4}, "o");
+  CheckEqual(output, {5}, " ");
+  CheckEqual(output, {6}, "W");
+  CheckEqual(output, {7}, "o");
+  CheckEqual(output, {8}, "r");
+  CheckEqual(output, {9}, "l");
+  CheckEqual(output, {10}, "d");
+  CheckEqual(output, {11}, "!");
+
+  input = std::make_shared<Tensor>("中国 你好!");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 6);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor2: " << output->ToString();
+  CheckEqual(output, {0}, "中");
+  CheckEqual(output, {1}, "国");
+  CheckEqual(output, {2}, " ");
+  CheckEqual(output, {3}, "你");
+  CheckEqual(output, {4}, "好");
+  CheckEqual(output, {5}, "!");
+
+  input = std::make_shared<Tensor>("中");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 1);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor3: " << output->ToString();
+  CheckEqual(output, {0}, "中");
+
+  input = std::make_shared<Tensor>("H");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 1);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor4: " << output->ToString();
+  CheckEqual(output, {0}, "H");
+
+  input = std::make_shared<Tensor>("  ");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 2);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor5: " << output->ToString();
+  CheckEqual(output, {0}, " ");
+  CheckEqual(output, {1}, " ");
+
+  input = std::make_shared<Tensor>("");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 1);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor6: " << output->ToString();
+  CheckEqual(output, {0}, "");
+}
@@ -0,0 +1,4 @@
+Welcome to Beijing!
+北京欢迎您!
+我喜欢English!
+  
@@ -0,0 +1,53 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing UnicodeCharTokenizer op in DE
+"""
+import mindspore.dataset as ds
+from mindspore import log as logger
+import mindspore.dataset.transforms.text.c_transforms as nlp
+import mindspore.dataset.transforms.text.utils as nlp_util
+
+DATA_FILE = "../data/dataset/testTokenizerData/1.txt"
+
+
+def split_by_unicode_char(input_strs):
+    """
+    Split utf-8 strings to unicode characters
+    """
+    out = []
+    for s in input_strs:
+        out.append([c for c in s])
+    return out
+
+
+def test_unicode_char_tokenizer():
+    """
+    Test UnicodeCharTokenizer
+    """
+    input_strs = ("Welcome to Beijing!", "北京欢迎您!", "我喜欢English!", "  ")
+    dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    tokenizer = nlp.UnicodeCharTokenizer()
+    dataset = dataset.map(operations=tokenizer)
+    tokens = []
+    for i in dataset.create_dict_iterator():
+        text = nlp_util.as_text(i['text']).tolist()
+        tokens.append(text)
+    logger.info("The out tokens is : {}".format(tokens))
+    assert split_by_unicode_char(input_strs) == tokens
+
+
+if __name__ == '__main__':
+    test_unicode_char_tokenizer()