!38 Synchronize with latest Ascend software suite (17 Jun 2020). Merge pull request !38 from yanghaoran/master, tag v0.5.0-beta.
@@ -22,12 +22,13 @@ from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm
 __all__ = [
     "Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram",
-    "to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber"
+    "to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber",
+    "PythonTokenizer"
 ]

 if platform.system().lower() != 'windows':
     from .transforms import UnicodeScriptTokenizer, WhitespaceTokenizer, CaseFold, NormalizeUTF8, \
-        RegexReplace, RegexTokenizer, BasicTokenizer, BertTokenizer
+        RegexReplace, RegexTokenizer, BasicTokenizer, BertTokenizer, PythonTokenizer
     __all__.append(["UnicodeScriptTokenizer", "WhitespaceTokenizer", "CaseFold", "NormalizeUTF8",
                     "RegexReplace", "RegexTokenizer", "BasicTokenizer", "BertTokenizer", "NormalizeForm"])
@@ -18,13 +18,14 @@ c transforms for all text related operators
 import os
 import re
 import platform
+import numpy as np

 import mindspore._c_dataengine as cde

-from .utils import JiebaMode, NormalizeForm
+from .utils import JiebaMode, NormalizeForm, to_str
 from .validators import check_lookup, check_jieba_add_dict, \
     check_jieba_add_word, check_jieba_init, check_ngram, check_pair_truncate, \
-    check_to_number
+    check_to_number, check_python_tokenizer
 from ..core.datatypes import mstype_to_detype
@@ -406,3 +407,25 @@ class ToNumber(cde.ToNumberOp):
         data_type = mstype_to_detype(data_type)
         self.data_type = str(data_type)
         super().__init__(data_type)
+
+
+class PythonTokenizer:
+    """
+    Callable class to be used for user-defined string tokenizer.
+
+    Args:
+        tokenizer (Callable): Python function that takes a `str` and returns a list of `str` as tokens.
+
+    Examples:
+        >>> def my_tokenizer(line):
+        >>>     return line.split()
+        >>> data = data.map(operations=PythonTokenizer(my_tokenizer))
+    """
+
+    @check_python_tokenizer
+    def __init__(self, tokenizer):
+        self.tokenizer = np.vectorize(lambda x: np.array(tokenizer(x), dtype='U'), signature='()->(n)')
+
+    def __call__(self, in_array):
+        in_array = to_str(in_array)
+        tokens = self.tokenizer(in_array)
+        return tokens
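The op itself is a thin wrapper over NumPy's generalized vectorization: the user's callable returns a Python list of tokens, and np.vectorize with signature='()->(n)' turns that into a 1-D string array per input sample. A standalone sketch of just that mechanism, using plain NumPy and a hypothetical my_tokenizer (no MindSpore required):

import numpy as np

def my_tokenizer(line):
    return line.split()

# '()->(n)': each scalar string maps to a 1-D array of n tokens.
tokenize = np.vectorize(lambda x: np.array(my_tokenizer(x), dtype='U'),
                        signature='()->(n)')

print(tokenize(np.array('Welcome to Beijing!')))   # ['Welcome' 'to' 'Beijing!']

Note that every string handled in a single call must yield the same token count n; this is not a concern when map applies the op row by row, as in the test further below.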
@@ -411,3 +411,25 @@ def check_to_number(method):

         return method(self, **kwargs)

     return new_method
+
+
+def check_python_tokenizer(method):
+    """A wrapper that wraps a parameter check to the original function (PythonTokenizer)."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        tokenizer = (list(args) + [None])[0]
+        if "tokenizer" in kwargs:
+            tokenizer = kwargs.get("tokenizer")
+
+        if tokenizer is None:
+            raise ValueError("tokenizer is a mandatory parameter.")
+
+        if not callable(tokenizer):
+            raise TypeError("tokenizer is not a callable python function")
+
+        kwargs["tokenizer"] = tokenizer
+
+        return method(self, **kwargs)
+
+    return new_method
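The wrapper pulls tokenizer out of the positional or keyword arguments, rejects a missing or non-callable value, and forwards it to the wrapped __init__ as a keyword argument. Based only on the code above, constructing the op would behave roughly as follows (a hypothetical session, not an excerpt from the test suite):

import mindspore.dataset.text as text

text.PythonTokenizer(lambda line: line.split())   # accepted: any callable
try:
    text.PythonTokenizer("not a function")
except TypeError as err:
    print(err)                                    # tokenizer is not a callable python function
try:
    text.PythonTokenizer()
except ValueError as err:
    print(err)                                    # tokenizer is a mandatory parameter.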
@@ -0,0 +1,52 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing PythonTokenizer op in DE
+"""
+import mindspore.dataset as ds
+import mindspore.dataset.text as text
+from mindspore import log as logger
+
+DATA_FILE = "../data/dataset/testTokenizerData/1.txt"
+
+
+def test_whitespace_tokenizer_ch():
+    """
+    Test PythonTokenizer
+    """
+    whitespace_strs = [["Welcome", "to", "Beijing!"],
+                       ["北京欢迎您!"],
+                       ["我喜欢English!"],
+                       [""]]
+
+    def my_tokenizer(line):
+        words = line.split()
+        if not words:
+            return [""]
+        return words
+
+    dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    tokenizer = text.PythonTokenizer(my_tokenizer)
+    dataset = dataset.map(operations=tokenizer, num_parallel_workers=1)
+    tokens = []
+    for i in dataset.create_dict_iterator():
+        s = text.to_str(i['text']).tolist()
+        tokens.append(s)
+    logger.info("The out tokens is : {}".format(tokens))
+    assert whitespace_strs == tokens
+
+
+if __name__ == '__main__':
+    test_whitespace_tokenizer_ch()
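One detail worth noting in the test's my_tokenizer is the empty-line guard: a plain str.split() on a blank line returns [], while the expected output above records that row as [""]. The guard can be checked with nothing but standard Python (a small sketch, independent of MindSpore):

def my_tokenizer(line):
    words = line.split()
    if not words:
        return [""]
    return words

assert my_tokenizer("Welcome to Beijing!") == ["Welcome", "to", "Beijing!"]
assert my_tokenizer("") == [""]   # bare "".split() would return [] and drop the row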