@@ -61,7 +61,7 @@ set(submodules
     $<TARGET_OBJECTS:kernels>
     $<TARGET_OBJECTS:kernels-image>
     $<TARGET_OBJECTS:kernels-data>
-    $<TARGET_OBJECTS:kernels-nlp>
+    $<TARGET_OBJECTS:kernels-text>
     $<TARGET_OBJECTS:APItoPython>
     $<TARGET_OBJECTS:engine-datasetops-source>
     $<TARGET_OBJECTS:engine-datasetops-source-sampler>

@@ -39,6 +39,7 @@
 #include "dataset/kernels/image/uniform_aug_op.h"
 #include "dataset/kernels/data/type_cast_op.h"
 #include "dataset/kernels/text/jieba_tokenizer_op.h"
+#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
 #include "dataset/engine/datasetops/source/cifar_op.h"
 #include "dataset/engine/datasetops/source/image_folder_op.h"
 #include "dataset/engine/datasetops/source/io_block.h"

@@ -407,12 +408,16 @@ void bindTensorOps4(py::module *m) {
                     py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
 }

-void bindTensorOps6(py::module *m) {
+void bindTensorOps5(py::module *m) {
   (void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(*m, "JiebaTokenizerOp", "")
     .def(py::init<const std::string, std::string, JiebaMode>(), py::arg("hmm_path"), py::arg("mp_path"),
          py::arg("mode") = JiebaMode::kMix)
     .def("add_word",
          [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); });
+
+  (void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
+    *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
+    .def(py::init<>());
 }

 void bindSamplerOps(py::module *m) {

@@ -534,7 +539,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
   bindTensorOps2(&m);
   bindTensorOps3(&m);
   bindTensorOps4(&m);
-  bindTensorOps6(&m);
+  bindTensorOps5(&m);
   bindSamplerOps(&m);
   bindDatasetOps(&m);
   bindInfoObjects(&m);
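
With the binding registered, the new op is reachable from Python through the `_c_dataengine` extension module. A minimal smoke-test sketch (direct use of the raw binding; the `UnicodeCharTokenizer` wrapper added later in this change is the intended entry point):

```python
import mindspore._c_dataengine as cde

# The raw binding takes no constructor arguments (py::init<>()).
op = cde.UnicodeCharTokenizerOp()

# TensorOps are normally passed to Dataset.map rather than invoked directly.
```
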
@@ -1,5 +1,6 @@
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
-add_library(kernels-nlp OBJECT
+add_library(kernels-text OBJECT
     jieba_tokenizer_op.cc
+    unicode_char_tokenizer_op.cc
     )

@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef DATASET_ENGINE_NLP_JIEBA_OP_H_
-#define DATASET_ENGINE_NLP_JIEBA_OP_H_
+#ifndef DATASET_ENGINE_TEXT_JIEBA_OP_H_
+#define DATASET_ENGINE_TEXT_JIEBA_OP_H_

 #include <string>
 #include <memory>

@@ -61,4 +61,4 @@ class JiebaTokenizerOp : public TensorOp {
 };
 }  // namespace dataset
 }  // namespace mindspore
-#endif  // DATASET_ENGINE_NLP_JIEBA_OP_H_
+#endif  // DATASET_ENGINE_TEXT_JIEBA_OP_H_
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
+
+#include <memory>
+#include <string>
+#include <string_view>
+#include <vector>
+
+#include "cppjieba/Unicode.hpp"
+
+using cppjieba::DecodeRunesInString;
+using cppjieba::RuneStrArray;
+
+namespace mindspore {
+namespace dataset {
+
+Status UnicodeCharTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
+    RETURN_STATUS_UNEXPECTED("The input tensor should be a scalar string tensor.");
+  }
+  std::string_view str;
+  RETURN_IF_NOT_OK(input->GetItemAt(&str, {}));
+  RuneStrArray runes;
+  if (!DecodeRunesInString(str.data(), str.size(), runes)) {
+    RETURN_STATUS_UNEXPECTED("Failed to decode UTF-8 string.");
+  }
+  std::vector<std::string> splits(runes.size());
+  for (size_t i = 0; i < runes.size(); i++) {
+    splits[i] = str.substr(runes[i].offset, runes[i].len);
+  }
+  if (splits.empty()) {
+    splits.emplace_back("");
+  }
+  *output = std::make_shared<Tensor>(splits, TensorShape({(dsize_t)splits.size()}));
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
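
`DecodeRunesInString` records the byte offset and byte length of every code point, so each `substr` slices exactly one UTF-8 character out of the input buffer. A rough Python model of the op's semantics (an illustrative sketch, not part of the patch):

```python
def unicode_char_tokenize(s: str) -> list:
    """Model of UnicodeCharTokenizerOp::Compute: one token per Unicode
    code point; an empty input still yields a single empty token."""
    tokens = list(s)
    return tokens if tokens else [""]

assert unicode_char_tokenize("中国 你好!") == ["中", "国", " ", "你", "好", "!"]
assert unicode_char_tokenize("") == [""]
```
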
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
+#define DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
+#include <memory>
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/tensor_op.h"
+#include "dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+class UnicodeCharTokenizerOp : public TensorOp {
+ public:
+  UnicodeCharTokenizerOp() {}
+
+  ~UnicodeCharTokenizerOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "UnicodeCharTokenizerOp"; }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
@@ -284,10 +284,10 @@ class Dataset:
         Examples:
             >>> import mindspore.dataset as ds
-            >>> import mindspore.dataset.transforms.nlp.utils as nlp
+            >>> import mindspore.dataset.transforms.text.utils as text
             >>> # declare a function which returns a Dataset object
             >>> def flat_map_func(x):
-            >>>     data_dir = nlp.as_text(x[0])
+            >>>     data_dir = text.as_text(x[0])
             >>>     d = ds.ImageFolderDatasetV2(data_dir)
             >>>     return d
             >>> # data is a Dataset object

@@ -18,3 +18,4 @@ image augmentation module which is developed with c++ opencv. Py_transforms
 provide more kinds of image augmentations which is developed with python PIL.
 """
 from .utils import as_text, JiebaMode
+from . import c_transforms
@@ -123,3 +123,9 @@ class JiebaTokenizer(cde.JiebaTokenizerOp):
         if not os.path.exists(model_path):
             raise ValueError(
                 " jieba mode file {} is not exist".format(model_path))
+
+
+class UnicodeCharTokenizer(cde.UnicodeCharTokenizerOp):
+    """
+    Tokenize a scalar tensor of UTF-8 string to Unicode characters.
+    """
@@ -33,9 +33,7 @@ def as_text(array, encoding='utf8'):

     if not isinstance(array, np.ndarray):
         raise ValueError('input should be a numpy array')

-    def decode(x):
-        return x.decode(encoding)
-
-    decode = np.vectorize(decode)
+    decode = np.vectorize(lambda x: x.decode(encoding))
     return decode(array)
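
The refactor folds the helper into a single `np.vectorize` call; behavior is unchanged, it still decodes element-wise. A quick check (assuming a NumPy array of UTF-8 bytes, which is what the validation above expects):

```python
import numpy as np
import mindspore.dataset.transforms.text.utils as text

arr = np.array([b"Welcome to Beijing!", "北京欢迎您!".encode("utf8")])
print(text.as_text(arr))  # ['Welcome to Beijing!' '北京欢迎您!']
```
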
@@ -69,6 +69,7 @@ SET(DE_UT_SRCS
     filter_op_test.cc
     concat_op_test.cc
     jieba_tokenizer_op_test.cc
+    tokenizer_op_test.cc
     )

 add_executable(de_ut_tests ${DE_UT_SRCS})
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <memory>
+#include <string>
+#include <string_view>
+#include <vector>
+
+#include "common/common.h"
+#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
+#include "gtest/gtest.h"
+#include "utils/log_adapter.h"
+
+using namespace mindspore::dataset;
+
+class MindDataTestTokenizerOp : public UT::Common {
+ public:
+  void CheckEqual(const std::shared_ptr<Tensor> &o,
+                  const std::vector<dsize_t> &index,
+                  const std::string &expect) {
+    std::string_view str;
+    Status s = o->GetItemAt(&str, index);
+    EXPECT_TRUE(s.IsOk());
+    EXPECT_EQ(str, expect);
+  }
+};
+
+TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) {
+  MS_LOG(INFO) << "Doing TestUnicodeCharTokenizerOp.";
+  std::unique_ptr<UnicodeCharTokenizerOp> op(new UnicodeCharTokenizerOp());
+  std::shared_ptr<Tensor> input = std::make_shared<Tensor>("Hello World!");
+  std::shared_ptr<Tensor> output;
+  Status s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 12);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor1: " << output->ToString();
+  CheckEqual(output, {0}, "H");
+  CheckEqual(output, {1}, "e");
+  CheckEqual(output, {2}, "l");
+  CheckEqual(output, {3}, "l");
+  CheckEqual(output, {4}, "o");
+  CheckEqual(output, {5}, " ");
+  CheckEqual(output, {6}, "W");
+  CheckEqual(output, {7}, "o");
+  CheckEqual(output, {8}, "r");
+  CheckEqual(output, {9}, "l");
+  CheckEqual(output, {10}, "d");
+  CheckEqual(output, {11}, "!");
+
+  input = std::make_shared<Tensor>("中国 你好!");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 6);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor2: " << output->ToString();
+  CheckEqual(output, {0}, "中");
+  CheckEqual(output, {1}, "国");
+  CheckEqual(output, {2}, " ");
+  CheckEqual(output, {3}, "你");
+  CheckEqual(output, {4}, "好");
+  CheckEqual(output, {5}, "!");
+
+  input = std::make_shared<Tensor>("中");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 1);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor3: " << output->ToString();
+  CheckEqual(output, {0}, "中");
+
+  input = std::make_shared<Tensor>("H");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 1);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor4: " << output->ToString();
+  CheckEqual(output, {0}, "H");
+
+  input = std::make_shared<Tensor>("  ");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 2);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor5: " << output->ToString();
+  CheckEqual(output, {0}, " ");
+  CheckEqual(output, {1}, " ");
+
+  input = std::make_shared<Tensor>("");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 1);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor6: " << output->ToString();
+  CheckEqual(output, {0}, "");
+}
@@ -0,0 +1,4 @@
+Welcome to Beijing!
+北京欢迎您!
+我喜欢English!
+  
@@ -0,0 +1,53 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing UnicodeCharTokenizer op in DE
+"""
+import mindspore.dataset as ds
+from mindspore import log as logger
+import mindspore.dataset.transforms.text.c_transforms as nlp
+import mindspore.dataset.transforms.text.utils as nlp_util
+
+DATA_FILE = "../data/dataset/testTokenizerData/1.txt"
+
+
+def split_by_unicode_char(input_strs):
+    """
+    Split UTF-8 strings into Unicode characters.
+    """
+    out = []
+    for s in input_strs:
+        out.append([c for c in s])
+    return out
+
+
+def test_unicode_char_tokenizer():
+    """
+    Test UnicodeCharTokenizer
+    """
+    input_strs = ("Welcome to Beijing!", "北京欢迎您!", "我喜欢English!", "  ")
+    dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    tokenizer = nlp.UnicodeCharTokenizer()
+    dataset = dataset.map(operations=tokenizer)
+    tokens = []
+    for i in dataset.create_dict_iterator():
+        text = nlp_util.as_text(i['text']).tolist()
+        tokens.append(text)
+    logger.info("The out tokens is: {}".format(tokens))
+    assert split_by_unicode_char(input_strs) == tokens
+
+
+if __name__ == '__main__':
+    test_unicode_char_tokenizer()