Merge pull request !1365 from h.farahat/text_namespacetags/v0.3.0-alpha
| @@ -52,7 +52,7 @@ add_subdirectory(core) | |||
| add_subdirectory(kernels) | |||
| add_subdirectory(engine) | |||
| add_subdirectory(api) | |||
| add_subdirectory(nlp) | |||
| add_subdirectory(text) | |||
| ###################################################################### | |||
| ################### Create _c_dataengine Library ###################### | |||
| @@ -62,7 +62,6 @@ set(submodules | |||
| $<TARGET_OBJECTS:kernels> | |||
| $<TARGET_OBJECTS:kernels-image> | |||
| $<TARGET_OBJECTS:kernels-data> | |||
| $<TARGET_OBJECTS:kernels-text> | |||
| $<TARGET_OBJECTS:APItoPython> | |||
| $<TARGET_OBJECTS:engine-datasetops-source> | |||
| $<TARGET_OBJECTS:engine-datasetops-source-sampler> | |||
| @@ -70,8 +69,8 @@ set(submodules | |||
| $<TARGET_OBJECTS:engine-datasetops> | |||
| $<TARGET_OBJECTS:engine-opt> | |||
| $<TARGET_OBJECTS:engine> | |||
| $<TARGET_OBJECTS:nlp> | |||
| $<TARGET_OBJECTS:nlp-kernels> | |||
| $<TARGET_OBJECTS:text> | |||
| $<TARGET_OBJECTS:text-kernels> | |||
| ) | |||
| if (ENABLE_TDTQUE) | |||
| @@ -38,10 +38,6 @@ | |||
| #include "dataset/kernels/image/resize_op.h" | |||
| #include "dataset/kernels/image/uniform_aug_op.h" | |||
| #include "dataset/kernels/data/type_cast_op.h" | |||
| #include "dataset/kernels/text/jieba_tokenizer_op.h" | |||
| #include "dataset/kernels/text/unicode_char_tokenizer_op.h" | |||
| #include "dataset/nlp/vocab.h" | |||
| #include "dataset/nlp/kernels/lookup_op.h" | |||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "dataset/engine/datasetops/source/io_block.h" | |||
| @@ -63,6 +59,10 @@ | |||
| #include "dataset/engine/datasetops/source/voc_op.h" | |||
| #include "dataset/engine/gnn/graph.h" | |||
| #include "dataset/kernels/data/to_float16_op.h" | |||
| #include "dataset/text/kernels/jieba_tokenizer_op.h" | |||
| #include "dataset/text/kernels/unicode_char_tokenizer_op.h" | |||
| #include "dataset/text/vocab.h" | |||
| #include "dataset/text/kernels/lookup_op.h" | |||
| #include "dataset/util/random.h" | |||
| #include "mindrecord/include/shard_operator.h" | |||
| #include "mindrecord/include/shard_pk_sample.h" | |||
| @@ -577,9 +577,9 @@ PYBIND11_MODULE(_c_dataengine, m) { | |||
| .value("TEXTFILE", OpName::kTextFile); | |||
| (void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic()) | |||
| .value("DE_INTER_JIEBA_MIX", JiebaMode::kMix) | |||
| .value("DE_INTER_JIEBA_MP", JiebaMode::kMp) | |||
| .value("DE_INTER_JIEBA_HMM", JiebaMode::kHmm) | |||
| .value("DE_JIEBA_MIX", JiebaMode::kMix) | |||
| .value("DE_JIEBA_MP", JiebaMode::kMp) | |||
| .value("DE_JIEBA_HMM", JiebaMode::kHmm) | |||
| .export_values(); | |||
| (void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic()) | |||
| @@ -2,7 +2,6 @@ add_subdirectory(image) | |||
| add_subdirectory(data) | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_subdirectory(text) | |||
| add_library(kernels OBJECT | |||
| py_func_op.cc | |||
| tensor_op.cc) | |||
| @@ -1,7 +0,0 @@ | |||
| add_subdirectory(kernels) | |||
| add_library(nlp OBJECT | |||
| vocab.cc | |||
| ) | |||
| add_dependencies(nlp nlp-kernels) | |||
| @@ -1,3 +0,0 @@ | |||
| add_library(nlp-kernels OBJECT | |||
| lookup_op.cc | |||
| ) | |||
| @@ -0,0 +1,7 @@ | |||
| add_subdirectory(kernels) | |||
| add_library(text OBJECT | |||
| vocab.cc | |||
| ) | |||
| add_dependencies(text text-kernels) | |||
| @@ -1,6 +1,7 @@ | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_library(kernels-text OBJECT | |||
| jieba_tokenizer_op.cc | |||
| unicode_char_tokenizer_op.cc | |||
| ) | |||
| add_library(text-kernels OBJECT | |||
| lookup_op.cc | |||
| jieba_tokenizer_op.cc | |||
| unicode_char_tokenizer_op.cc | |||
| ) | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/kernels/text/jieba_tokenizer_op.h" | |||
| #include "dataset/text/kernels/jieba_tokenizer_op.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/nlp/kernels/lookup_op.h" | |||
| #include "dataset/text/kernels/lookup_op.h" | |||
| #include <string> | |||
| @@ -24,7 +24,7 @@ | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "dataset/util/status.h" | |||
| #include "dataset/nlp/vocab.h" | |||
| #include "dataset/text/vocab.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/kernels/text/unicode_char_tokenizer_op.h" | |||
| #include "dataset/text/kernels/unicode_char_tokenizer_op.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <string_view> | |||
| @@ -17,7 +17,7 @@ | |||
| #include <map> | |||
| #include <utility> | |||
| #include "dataset/nlp/vocab.h" | |||
| #include "dataset/text/vocab.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -284,10 +284,10 @@ class Dataset: | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> import mindspore.dataset.transforms.text.utils as text | |||
| >>> import mindspore.dataset.text as text | |||
| >>> # declare a function which returns a Dataset object | |||
| >>> def flat_map_func(x): | |||
| >>> data_dir = text.as_text(x[0]) | |||
| >>> data_dir = text.to_str(x[0]) | |||
| >>> d = ds.ImageFolderDatasetV2(data_dir) | |||
| >>> return d | |||
| >>> # data is a Dataset object | |||
| @@ -15,5 +15,5 @@ | |||
| """ | |||
| mindspore.dataset.text | |||
| """ | |||
| from .c_transforms import * | |||
| from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer | |||
| from .utils import to_str, to_bytes, JiebaMode, Vocab | |||
| @@ -11,20 +11,40 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| This module c_transforms provides common nlp operations. | |||
| c transforms for all text related operators | |||
| """ | |||
| import os | |||
| import re | |||
| import mindspore._c_dataengine as cde | |||
| from .utils import JiebaMode | |||
| from .validators import check_jieba_add_dict, check_jieba_add_word, check_jieba_init | |||
| from .validators import check_lookup, check_jieba_add_dict, \ | |||
| check_jieba_add_word, check_jieba_init | |||
| class Lookup(cde.LookupOp): | |||
| """ | |||
| Lookup operator that looks up a word to an id | |||
| Args: | |||
| vocab(Vocab): a Vocab object | |||
| unknown(None,int): default id to lookup a word that is out of vocab | |||
| """ | |||
| @check_lookup | |||
| def __init__(self, vocab, unknown=None): | |||
| if unknown is None: | |||
| super().__init__(vocab) | |||
| else: | |||
| super().__init__(vocab, unknown) | |||
| DE_C_INTER_JIEBA_MODE = { | |||
| JiebaMode.MIX: cde.JiebaMode.DE_INTER_JIEBA_MIX, | |||
| JiebaMode.MP: cde.JiebaMode.DE_INTER_JIEBA_MP, | |||
| JiebaMode.HMM: cde.JiebaMode.DE_INTER_JIEBA_HMM | |||
| JiebaMode.MIX: cde.JiebaMode.DE_JIEBA_MIX, | |||
| JiebaMode.MP: cde.JiebaMode.DE_JIEBA_MP, | |||
| JiebaMode.HMM: cde.JiebaMode.DE_JIEBA_HMM | |||
| } | |||
| @@ -41,6 +61,7 @@ class JiebaTokenizer(cde.JiebaTokenizerOp): | |||
| "HMM" mode will tokenize with Hiddel Markov Model Segment algorithm, | |||
| "MIX" model will tokenize with a mix of MPSegment and HMMSegment algorithm. | |||
| """ | |||
| @check_jieba_init | |||
| def __init__(self, hmm_path, mp_path, mode=JiebaMode.MIX): | |||
| self.mode = mode | |||
| @@ -12,11 +12,14 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ | |||
| c transforms for all text related operators | |||
| Some basic function for nlp | |||
| """ | |||
| from enum import IntEnum | |||
| import mindspore._c_dataengine as cde | |||
| from .validators import check_lookup, check_from_list, check_from_dict, check_from_file | |||
| import numpy as np | |||
| from .validators import check_from_file, check_from_list, check_from_dict | |||
| class Vocab(cde.Vocab): | |||
| @@ -61,17 +64,43 @@ class Vocab(cde.Vocab): | |||
| return super().from_dict(word_dict) | |||
| class Lookup(cde.LookupOp): | |||
| def to_str(array, encoding='utf8'): | |||
| """ | |||
| Lookup operator that looks up a word to an id | |||
| Convert numpy array of `bytes` to array of `str` by decoding each element based on charset `encoding`. | |||
| Args: | |||
| vocab(Vocab): a Vocab object | |||
| unknown(None,int): default id to lookup a word that is out of vocab | |||
| array (numpy array): Array of type `bytes` representing strings. | |||
| encoding (string): Indicating the charset for decoding. | |||
| Returns: | |||
| Numpy array of `str`. | |||
| """ | |||
| if not isinstance(array, np.ndarray): | |||
| raise ValueError('input should be a numpy array') | |||
| return np.char.decode(array, encoding) | |||
| def to_bytes(array, encoding='utf8'): | |||
| """ | |||
| Convert numpy array of `str` to array of `bytes` by encoding each element based on charset `encoding`. | |||
| Args: | |||
| array (numpy array): Array of type `str` representing strings. | |||
| encoding (string): Indicating the charset for encoding. | |||
| Returns: | |||
| Numpy array of `bytes`. | |||
| """ | |||
| if not isinstance(array, np.ndarray): | |||
| raise ValueError('input should be a numpy array') | |||
| return np.char.encode(array, encoding) | |||
| @check_lookup | |||
| def __init__(self, vocab, unknown=None): | |||
| if unknown is None: | |||
| super().__init__(vocab) | |||
| else: | |||
| super().__init__(vocab, unknown) | |||
| class JiebaMode(IntEnum): | |||
| MIX = 0 | |||
| MP = 1 | |||
| HMM = 2 | |||
| @@ -17,8 +17,11 @@ validators for text ops | |||
| """ | |||
| from functools import wraps | |||
| import mindspore._c_dataengine as cde | |||
| from ..transforms.validators import check_uint32 | |||
| def check_lookup(method): | |||
| """A wrapper that wrap a parameter checker to the original function(crop operation).""" | |||
| @@ -106,3 +109,67 @@ def check_from_dict(method): | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| def check_jieba_init(method): | |||
| """Wrapper method to check the parameters of jieba add word.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| hmm_path, mp_path, model = (list(args) + 3 * [None])[:3] | |||
| if "hmm_path" in kwargs: | |||
| hmm_path = kwargs.get("hmm_path") | |||
| if "mp_path" in kwargs: | |||
| mp_path = kwargs.get("mp_path") | |||
| if hmm_path is None: | |||
| raise ValueError( | |||
| "the dict of HMMSegment in cppjieba is not provided") | |||
| kwargs["hmm_path"] = hmm_path | |||
| if mp_path is None: | |||
| raise ValueError( | |||
| "the dict of MPSegment in cppjieba is not provided") | |||
| kwargs["mp_path"] = mp_path | |||
| if model is not None: | |||
| kwargs["model"] = model | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| def check_jieba_add_word(method): | |||
| """Wrapper method to check the parameters of jieba add word.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| word, freq = (list(args) + 2 * [None])[:2] | |||
| if "word" in kwargs: | |||
| word = kwargs.get("word") | |||
| if "freq" in kwargs: | |||
| freq = kwargs.get("freq") | |||
| if word is None: | |||
| raise ValueError("word is not provided") | |||
| kwargs["word"] = word | |||
| if freq is not None: | |||
| check_uint32(freq) | |||
| kwargs["freq"] = freq | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| def check_jieba_add_dict(method): | |||
| """Wrapper method to check the parameters of add dict""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| user_dict = (list(args) + [None])[0] | |||
| if "user_dict" in kwargs: | |||
| user_dict = kwargs.get("user_dict") | |||
| if user_dict is None: | |||
| raise ValueError("user_dict is not provided") | |||
| kwargs["user_dict"] = user_dict | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| @@ -1,21 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ | |||
| This module is to support nlp augmentations. It includes two parts: | |||
| c_transforms and py_transforms. C_transforms is a high performance | |||
| image augmentation module which is developed with c++ opencv. Py_transforms | |||
| provide more kinds of image augmentations which is developed with python PIL. | |||
| """ | |||
| from .utils import as_text, JiebaMode | |||
| from . import c_transforms | |||
| @@ -1,43 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ | |||
| Some basic function for nlp | |||
| """ | |||
| from enum import IntEnum | |||
| import numpy as np | |||
| def as_text(array, encoding='utf8'): | |||
| """ | |||
| Convert data of array to unicode. | |||
| Args: | |||
| array (numpy array): Data of array should be ASCII values of each character after converted. | |||
| encoding (string): Indicating the charset for decoding. | |||
| Returns: | |||
| A 'str' object. | |||
| """ | |||
| if not isinstance(array, np.ndarray): | |||
| raise ValueError('input should be a numpy array') | |||
| decode = np.vectorize(lambda x: x.decode(encoding)) | |||
| return decode(array) | |||
| class JiebaMode(IntEnum): | |||
| MIX = 0 | |||
| MP = 1 | |||
| HMM = 2 | |||
| @@ -1,79 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Validators for TensorOps. | |||
| """ | |||
| from functools import wraps | |||
| from ...transforms.validators import check_uint32 | |||
| def check_jieba_init(method): | |||
| """Wrapper method to check the parameters of jieba add word.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| hmm_path, mp_path, model = (list(args) + 3 * [None])[:3] | |||
| if "hmm_path" in kwargs: | |||
| hmm_path = kwargs.get("hmm_path") | |||
| if "mp_path" in kwargs: | |||
| mp_path = kwargs.get("mp_path") | |||
| if hmm_path is None: | |||
| raise ValueError( | |||
| "the dict of HMMSegment in cppjieba is not provided") | |||
| kwargs["hmm_path"] = hmm_path | |||
| if mp_path is None: | |||
| raise ValueError( | |||
| "the dict of MPSegment in cppjieba is not provided") | |||
| kwargs["mp_path"] = mp_path | |||
| if model is not None: | |||
| kwargs["model"] = model | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| def check_jieba_add_word(method): | |||
| """Wrapper method to check the parameters of jieba add word.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| word, freq = (list(args) + 2 * [None])[:2] | |||
| if "word" in kwargs: | |||
| word = kwargs.get("word") | |||
| if "freq" in kwargs: | |||
| freq = kwargs.get("freq") | |||
| if word is None: | |||
| raise ValueError("word is not provided") | |||
| kwargs["word"] = word | |||
| if freq is not None: | |||
| check_uint32(freq) | |||
| kwargs["freq"] = freq | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| def check_jieba_add_dict(method): | |||
| """Wrapper method to check the parameters of add dict""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| user_dict = (list(args) + [None])[0] | |||
| if "user_dict" in kwargs: | |||
| user_dict = kwargs.get("user_dict") | |||
| if user_dict is None: | |||
| raise ValueError("user_dict is not provided") | |||
| kwargs["user_dict"] = user_dict | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| @@ -18,7 +18,7 @@ | |||
| #include <string_view> | |||
| #include "common/common.h" | |||
| #include "dataset/kernels/text/jieba_tokenizer_op.h" | |||
| #include "dataset/text/kernels/jieba_tokenizer_op.h" | |||
| #include "gtest/gtest.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -18,7 +18,7 @@ | |||
| #include <string_view> | |||
| #include "common/common.h" | |||
| #include "dataset/kernels/text/unicode_char_tokenizer_op.h" | |||
| #include "dataset/text/kernels/unicode_char_tokenizer_op.h" | |||
| #include "gtest/gtest.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -0,0 +1,2 @@ | |||
| home is behind the world ahead | |||
| is behind home ahead world the | |||
| @@ -13,7 +13,6 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.text.utils as nlp | |||
| from mindspore import log as logger | |||
| DATA_FILE = "../data/dataset/testTextFileDataset/1.txt" | |||
| @@ -24,7 +24,6 @@ def test_flat_map_1(): | |||
| ''' | |||
| DATA_FILE records the path of image folders, load the images from them. | |||
| ''' | |||
| import mindspore.dataset.transforms.text.utils as nlp | |||
| def flat_map_func(x): | |||
| data_dir = x[0].item().decode('utf8') | |||
| @@ -45,7 +44,6 @@ def test_flat_map_2(): | |||
| ''' | |||
| Flatten 3D structure data | |||
| ''' | |||
| import mindspore.dataset.transforms.text.utils as nlp | |||
| def flat_map_func_1(x): | |||
| data_dir = x[0].item().decode('utf8') | |||
| @@ -27,7 +27,7 @@ import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| from mindspore import log as logger | |||
| from mindspore.dataset.transforms.vision import Inter | |||
| from mindspore.dataset.transforms.text import as_text | |||
| from mindspore.dataset.text import to_str | |||
| from mindspore.mindrecord import FileWriter | |||
| FILES_NUM = 4 | |||
| @@ -73,7 +73,7 @@ def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file): | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- item[file_name]: \ | |||
| {}------------------------".format(as_text(item["file_name"]))) | |||
| {}------------------------".format(to_str(item["file_name"]))) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| @@ -93,7 +93,7 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): | |||
| logger.info("-------------- item[data]: \ | |||
| {}------------------------".format(item["data"][:10])) | |||
| logger.info("-------------- item[file_name]: \ | |||
| {}------------------------".format(as_text(item["file_name"]))) | |||
| {}------------------------".format(to_str(item["file_name"]))) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| @@ -111,7 +111,7 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file): | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- item[file_name]: \ | |||
| {}------------------------".format(as_text(item["file_name"]))) | |||
| {}------------------------".format(to_str(item["file_name"]))) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| @@ -128,7 +128,7 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file): | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- item[file_name]: \ | |||
| {}------------------------".format(as_text(item["file_name"]))) | |||
| {}------------------------".format(to_str(item["file_name"]))) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.text as text | |||
| # this file contains "home is behind the world head" each word is 1 line | |||
| DATA_FILE = "../data/dataset/testVocab/words.txt" | |||
| VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt" | |||
| HMM_FILE = "../data/dataset/jiebadict/hmm_model.utf8" | |||
| MP_FILE = "../data/dataset/jiebadict/jieba.dict.utf8" | |||
| def test_on_tokenized_line(): | |||
| data = ds.TextFileDataset("../data/dataset/testVocab/lines.txt", shuffle=False) | |||
| jieba_op = text.JiebaTokenizer(HMM_FILE, MP_FILE, mode=text.JiebaMode.MP) | |||
| with open(VOCAB_FILE, 'r') as f: | |||
| for line in f: | |||
| word = line.split(',')[0] | |||
| jieba_op.add_word(word) | |||
| data = data.map(input_columns=["text"], operations=jieba_op) | |||
| vocab = text.Vocab.from_file(VOCAB_FILE, ",") | |||
| lookup = text.Lookup(vocab) | |||
| data = data.map(input_columns=["text"], operations=lookup) | |||
| res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14], | |||
| [11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32) | |||
| for i, d in enumerate(data.create_dict_iterator()): | |||
| np.testing.assert_array_equal(d["text"], res[i]), i | |||
| if __name__ == '__main__': | |||
| test_on_tokenized_line() | |||
| @@ -14,8 +14,8 @@ | |||
| # ============================================================================== | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| from mindspore.dataset.transforms.text.c_transforms import JiebaTokenizer | |||
| from mindspore.dataset.transforms.text.utils import JiebaMode, as_text | |||
| from mindspore.dataset.text import JiebaTokenizer | |||
| from mindspore.dataset.text import JiebaMode, to_str | |||
| DATA_FILE = "../data/dataset/testJiebaDataset/3.txt" | |||
| DATA_ALL_FILE = "../data/dataset/testJiebaDataset/*" | |||
| @@ -33,7 +33,7 @@ def test_jieba_1(): | |||
| expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'] | |||
| ret = [] | |||
| for i in data.create_dict_iterator(): | |||
| ret = as_text(i["text"]) | |||
| ret = to_str(i["text"]) | |||
| for index, item in enumerate(ret): | |||
| assert item == expect[index] | |||
| @@ -46,7 +46,7 @@ def test_jieba_1_1(): | |||
| operations=jieba_op, num_parallel_workers=1) | |||
| expect = ['今天', '天气', '太', '好', '了', '我们', '一起', '去', '外面', '玩', '吧'] | |||
| for i in data.create_dict_iterator(): | |||
| ret = as_text(i["text"]) | |||
| ret = to_str(i["text"]) | |||
| for index, item in enumerate(ret): | |||
| assert item == expect[index] | |||
| @@ -59,7 +59,7 @@ def test_jieba_1_2(): | |||
| operations=jieba_op, num_parallel_workers=1) | |||
| expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'] | |||
| for i in data.create_dict_iterator(): | |||
| ret = as_text(i["text"]) | |||
| ret = to_str(i["text"]) | |||
| for index, item in enumerate(ret): | |||
| assert item == expect[index] | |||
| @@ -74,7 +74,7 @@ def test_jieba_2(): | |||
| data = data.map(input_columns=["text"], | |||
| operations=jieba_op, num_parallel_workers=2) | |||
| for i in data.create_dict_iterator(): | |||
| ret = as_text(i["text"]) | |||
| ret = to_str(i["text"]) | |||
| for index, item in enumerate(ret): | |||
| assert item == expect[index] | |||
| @@ -89,7 +89,7 @@ def test_jieba_2_1(): | |||
| operations=jieba_op, num_parallel_workers=2) | |||
| expect = ['男默女泪', '市', '长江大桥'] | |||
| for i in data.create_dict_iterator(): | |||
| ret = as_text(i["text"]) | |||
| ret = to_str(i["text"]) | |||
| for index, item in enumerate(ret): | |||
| assert item == expect[index] | |||
| @@ -113,7 +113,7 @@ def test_jieba_2_3(): | |||
| operations=jieba_op, num_parallel_workers=2) | |||
| expect = ['江州', '市长', '江大桥', '参加', '了', '长江大桥', '的', '通车', '仪式'] | |||
| for i in data.create_dict_iterator(): | |||
| ret = as_text(i["text"]) | |||
| ret = to_str(i["text"]) | |||
| for index, item in enumerate(ret): | |||
| assert item == expect[index] | |||
| @@ -131,7 +131,7 @@ def test_jieba_3(): | |||
| operations=jieba_op, num_parallel_workers=1) | |||
| expect = ['男默女泪', '市', '长江大桥'] | |||
| for i in data.create_dict_iterator(): | |||
| ret = as_text(i["text"]) | |||
| ret = to_str(i["text"]) | |||
| for index, item in enumerate(ret): | |||
| assert item == expect[index] | |||
| @@ -150,7 +150,7 @@ def test_jieba_3_1(): | |||
| operations=jieba_op, num_parallel_workers=1) | |||
| expect = ['男默女泪', '市长', '江大桥'] | |||
| for i in data.create_dict_iterator(): | |||
| ret = as_text(i["text"]) | |||
| ret = to_str(i["text"]) | |||
| for index, item in enumerate(ret): | |||
| assert item == expect[index] | |||
| @@ -166,7 +166,7 @@ def test_jieba_4(): | |||
| operations=jieba_op, num_parallel_workers=1) | |||
| expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'] | |||
| for i in data.create_dict_iterator(): | |||
| ret = as_text(i["text"]) | |||
| ret = to_str(i["text"]) | |||
| for index, item in enumerate(ret): | |||
| assert item == expect[index] | |||
| @@ -192,7 +192,7 @@ def test_jieba_5(): | |||
| operations=jieba_op, num_parallel_workers=1) | |||
| expect = ['江州', '市长', '江大桥', '参加', '了', '长江大桥', '的', '通车', '仪式'] | |||
| for i in data.create_dict_iterator(): | |||
| ret = as_text(i["text"]) | |||
| ret = to_str(i["text"]) | |||
| for index, item in enumerate(ret): | |||
| assert item == expect[index] | |||
| @@ -203,7 +203,7 @@ def gen(): | |||
| def pytoken_op(input_data): | |||
| te = str(as_text(input_data)) | |||
| te = str(to_str(input_data)) | |||
| tokens = [] | |||
| tokens.append(te[:5].encode("UTF8")) | |||
| tokens.append(te[5:10].encode("UTF8")) | |||
| @@ -217,7 +217,7 @@ def test_jieba_6(): | |||
| operations=pytoken_op, num_parallel_workers=1) | |||
| expect = ['今天天气太', '好了我们一', '起去外面玩吧'] | |||
| for i in data.create_dict_iterator(): | |||
| ret = as_text(i["text"]) | |||
| ret = to_str(i["text"]) | |||
| for index, item in enumerate(ret): | |||
| assert item == expect[index] | |||
| @@ -16,6 +16,8 @@ import mindspore._c_dataengine as cde | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore.dataset.text import to_str, to_bytes | |||
| import mindspore.dataset as ds | |||
| import mindspore.common.dtype as mstype | |||
| @@ -65,7 +67,8 @@ def test_map(): | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| def split(b): | |||
| splits = b.item().decode("utf8").split() | |||
| s = to_str(b) | |||
| splits = s.item().split() | |||
| return np.array(splits, dtype='S') | |||
| data = data.map(input_columns=["col"], operations=split) | |||
| @@ -74,11 +77,20 @@ def test_map(): | |||
| np.testing.assert_array_equal(d[0], expected) | |||
| def as_str(arr): | |||
| def decode(s): return s.decode("utf8") | |||
| def test_map2(): | |||
| def gen(): | |||
| yield np.array(["ab cde 121"], dtype='S'), | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| def upper(b): | |||
| out = np.char.upper(b) | |||
| return out | |||
| decode_v = np.vectorize(decode) | |||
| return decode_v(arr) | |||
| data = data.map(input_columns=["col"], operations=upper) | |||
| expected = np.array(["AB CDE 121"], dtype='S') | |||
| for d in data: | |||
| np.testing.assert_array_equal(d[0], expected) | |||
| line = np.array(["This is a text file.", | |||
| @@ -106,9 +118,9 @@ def test_tfrecord1(): | |||
| assert d["line"].shape == line[i].shape | |||
| assert d["words"].shape == words[i].shape | |||
| assert d["chinese"].shape == chinese[i].shape | |||
| np.testing.assert_array_equal(line[i], as_str(d["line"])) | |||
| np.testing.assert_array_equal(words[i], as_str(d["words"])) | |||
| np.testing.assert_array_equal(chinese[i], as_str(d["chinese"])) | |||
| np.testing.assert_array_equal(line[i], to_str(d["line"])) | |||
| np.testing.assert_array_equal(words[i], to_str(d["words"])) | |||
| np.testing.assert_array_equal(chinese[i], to_str(d["chinese"])) | |||
| def test_tfrecord2(): | |||
| @@ -118,9 +130,9 @@ def test_tfrecord2(): | |||
| assert d["line"].shape == line[i].shape | |||
| assert d["words"].shape == words[i].shape | |||
| assert d["chinese"].shape == chinese[i].shape | |||
| np.testing.assert_array_equal(line[i], as_str(d["line"])) | |||
| np.testing.assert_array_equal(words[i], as_str(d["words"])) | |||
| np.testing.assert_array_equal(chinese[i], as_str(d["chinese"])) | |||
| np.testing.assert_array_equal(line[i], to_str(d["line"])) | |||
| np.testing.assert_array_equal(words[i], to_str(d["words"])) | |||
| np.testing.assert_array_equal(chinese[i], to_str(d["chinese"])) | |||
| def test_tfrecord3(): | |||
| @@ -135,9 +147,9 @@ def test_tfrecord3(): | |||
| assert d["line"].shape == line[i].shape | |||
| assert d["words"].shape == words[i].reshape([2, 2]).shape | |||
| assert d["chinese"].shape == chinese[i].shape | |||
| np.testing.assert_array_equal(line[i], as_str(d["line"])) | |||
| np.testing.assert_array_equal(words[i].reshape([2, 2]), as_str(d["words"])) | |||
| np.testing.assert_array_equal(chinese[i], as_str(d["chinese"])) | |||
| np.testing.assert_array_equal(line[i], to_str(d["line"])) | |||
| np.testing.assert_array_equal(words[i].reshape([2, 2]), to_str(d["words"])) | |||
| np.testing.assert_array_equal(chinese[i], to_str(d["chinese"])) | |||
| def create_text_mindrecord(): | |||
| @@ -167,16 +179,17 @@ def test_mindrecord(): | |||
| for i, d in enumerate(data.create_dict_iterator()): | |||
| assert d["english"].shape == line[i].shape | |||
| assert d["chinese"].shape == chinese[i].shape | |||
| np.testing.assert_array_equal(line[i], as_str(d["english"])) | |||
| np.testing.assert_array_equal(chinese[i], as_str(d["chinese"])) | |||
| np.testing.assert_array_equal(line[i], to_str(d["english"])) | |||
| np.testing.assert_array_equal(chinese[i], to_str(d["chinese"])) | |||
| if __name__ == '__main__': | |||
| # test_generator() | |||
| # test_basic() | |||
| # test_batching_strings() | |||
| test_generator() | |||
| test_basic() | |||
| test_batching_strings() | |||
| test_map() | |||
| # test_tfrecord1() | |||
| # test_tfrecord2() | |||
| # test_tfrecord3() | |||
| # test_mindrecord() | |||
| test_map2() | |||
| test_tfrecord1() | |||
| test_tfrecord2() | |||
| test_tfrecord3() | |||
| test_mindrecord() | |||
| @@ -17,8 +17,7 @@ Testing UnicodeCharTokenizer op in DE | |||
| """ | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| import mindspore.dataset.transforms.text.c_transforms as nlp | |||
| import mindspore.dataset.transforms.text.utils as nlp_util | |||
| import mindspore.dataset.text as nlp | |||
| DATA_FILE = "../data/dataset/testTokenizerData/1.txt" | |||
| @@ -43,7 +42,7 @@ def test_unicode_char_tokenizer(): | |||
| dataset = dataset.map(operations=tokenizer) | |||
| tokens = [] | |||
| for i in dataset.create_dict_iterator(): | |||
| text = nlp_util.as_text(i['text']).tolist() | |||
| text = nlp.to_str(i['text']).tolist() | |||
| tokens.append(text) | |||
| logger.info("The out tokens is : {}".format(tokens)) | |||
| assert split_by_unicode_char(input_strs) == tokens | |||