| @@ -77,6 +77,7 @@ | |||
| #include "minddata/dataset/text/kernels/jieba_tokenizer_op.h" | |||
| #include "minddata/dataset/text/kernels/lookup_op.h" | |||
| #include "minddata/dataset/text/kernels/ngram_op.h" | |||
| #include "minddata/dataset/text/kernels/sliding_window_op.h" | |||
| #include "minddata/dataset/text/kernels/to_number_op.h" | |||
| #include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h" | |||
| #include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" | |||
| @@ -640,6 +641,9 @@ void bindTokenizerOps(py::module *m) { | |||
| py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, | |||
| py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), | |||
| py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); | |||
| (void)py::class_<SlidingWindowOp, TensorOp, std::shared_ptr<SlidingWindowOp>>( | |||
| *m, "SlidingWindowOp", "TensorOp to apply sliding window to a 1-D Tensor.") | |||
| .def(py::init<uint32_t, int32_t>(), py::arg("width"), py::arg("axis")); | |||
| } | |||
| void bindDependIcuTokenizerOps(py::module *m) { | |||
| @@ -120,6 +120,7 @@ constexpr char kCaseFoldOp[] = "CaseFoldOp"; | |||
| constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp"; | |||
| constexpr char kLookupOp[] = "LookupOp"; | |||
| constexpr char kNgramOp[] = "NgramOp"; | |||
| constexpr char kSlidingWindowOp[] = "SlidingWindowOp"; | |||
| constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op"; | |||
| constexpr char kRegexReplaceOp[] = "RegexReplaceOp"; | |||
| constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp"; | |||
| @@ -12,10 +12,12 @@ if (NOT (CMAKE_SYSTEM_NAME MATCHES "Windows")) | |||
| whitespace_tokenizer_op.cc) | |||
| endif() | |||
| add_library(text-kernels OBJECT | |||
| data_utils.cc | |||
| lookup_op.cc | |||
| jieba_tokenizer_op.cc | |||
| unicode_char_tokenizer_op.cc | |||
| ngram_op.cc | |||
| sliding_window_op.cc | |||
| wordpiece_tokenizer_op.cc | |||
| truncate_sequence_pair_op.cc | |||
| to_number_op.cc | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/text/kernels/data_utils.h" | |||
| #include <algorithm> | |||
| #include <limits> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/pybind_support.h" | |||
| #include "minddata/dataset/kernels/data/type_cast_op.h" | |||
| #include "minddata/dataset/kernels/data/slice_op.h" | |||
| #include "minddata/dataset/kernels/data/concatenate_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, TensorShape out_shape, | |||
| uint32_t width, int32_t axis) { | |||
| // if the data row has fewer items than width, the corresponding result row will be empty | |||
| if (out_shape.Size() == 0) { | |||
| MS_LOG(WARNING) << "The data row has fewer items than width, the result will be empty."; | |||
| if (input->type().value() == DataType::DE_STRING) { | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(output, std::vector<std::string>{}, TensorShape({0}))); | |||
| } else { | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, TensorShape({0}), input->type())); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| axis = Tensor::HandleNeg(axis, input->shape().Size()); | |||
| int32_t axis_end = input->shape()[axis]; | |||
| std::shared_ptr<Tensor> tmp; | |||
| auto concatenate_op = std::make_unique<ConcatenateOp>(axis, nullptr, nullptr); | |||
| // Slice on specified axis and concatenate on new axis | |||
| for (int32_t i = 0; i + width <= axis_end; i++) { | |||
| auto slice_op = std::make_unique<SliceOp>(Slice(i, i + width, 1)); | |||
| slice_op->Compute(input, &tmp); | |||
| if (i == 0) { | |||
| *output = tmp; | |||
| } else { | |||
| TensorRow in({*output, tmp}); | |||
| TensorRow out_row; | |||
| concatenate_op->Compute(in, &out_row); | |||
| *output = out_row[0]; | |||
| } | |||
| } | |||
| (*output)->Reshape(out_shape); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_KERNELS_TEXT_DATA_UTILS_H_ | |||
| #define DATASET_KERNELS_TEXT_DATA_UTILS_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/core/data_type.h" | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/core/cv_tensor.h" | |||
| #include "minddata/dataset/core/tensor_shape.h" | |||
| #include "minddata/dataset/core/tensor_row.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief Helper method that perform sliding window on input tensor. | |||
| /// \param[in] input - Input tensor. | |||
| /// \param[in] out_shape - Output shape of output tensor. | |||
| /// \param[in] width - The axis along which sliding window is computed. | |||
| /// \param[in] axis - The width of the window. | |||
| /// \param[out] output - Output tensor | |||
| /// \return Status return code | |||
| Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, TensorShape out_shape, | |||
| uint32_t width, int32_t axis); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_KERNELS_TEXT_DATA_UTILS_H_ | |||
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/text/kernels/sliding_window_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status SlidingWindowOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| IO_CHECK(input, output); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SlidingWindosOp supports 1D Tensors only for now."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(axis_ == 0 || axis_ == -1, "axis supports 0 or -1 only for now."); | |||
| std::vector<TensorShape> input_shape = {input->shape()}; | |||
| std::vector<TensorShape> output_shape = {TensorShape({})}; | |||
| RETURN_IF_NOT_OK(OutputShape(input_shape, output_shape)); | |||
| RETURN_IF_NOT_OK(SlidingWindowHelper(input, output, output_shape[0], width_, axis_)); | |||
| return Status::OK(); | |||
| } | |||
| Status SlidingWindowOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput(), "incorrect num of inputs\n"); | |||
| int32_t axis = Tensor::HandleNeg(axis_, inputs[0].Size()); | |||
| TensorShape input_shape = inputs[0]; | |||
| std::vector<dsize_t> output_shape_initializer; | |||
| // if a data row has fewer items than width, the corresponding result row will be empty. | |||
| if (input_shape[axis] >= width_) { | |||
| for (int32_t idx = 0; idx < input_shape.Size(); ++idx) { | |||
| if (idx != axis) { | |||
| output_shape_initializer.push_back(input_shape[idx]); | |||
| } else { | |||
| output_shape_initializer.push_back(input_shape[idx] - (width_ - 1)); | |||
| output_shape_initializer.push_back(width_); | |||
| } | |||
| } | |||
| } | |||
| outputs.pop_back(); | |||
| outputs.emplace_back(TensorShape(output_shape_initializer)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(outputs.size() == NumOutput(), "incorrect num of outputs\n"); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_ | |||
| #define DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/kernels/tensor_op.h" | |||
| #include "minddata/dataset/text/kernels/data_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class SlidingWindowOp : public TensorOp { | |||
| public: | |||
| /// \brief Constructor of SlidingWindowOp. | |||
| /// \param[in] width - The axis along which sliding window is computed. | |||
| /// \param[in] axis - The width of the window. | |||
| /// \return Status return code | |||
| explicit SlidingWindowOp(uint32_t width, int32_t axis = 0) : width_(width), axis_(axis) {} | |||
| /// \brief Destructor of SlidingWindowOp. | |||
| ~SlidingWindowOp() override = default; | |||
| /// \brief Perform sliding window to tensor. | |||
| /// \param[in] input - Input tensor of Op. | |||
| /// \param[out] output - output tensor of Op. | |||
| /// \return Status return code | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| /// \brief Calculate tensor shape for output tensor. | |||
| /// \param[in] inputs - Input tensor shapes. | |||
| /// \param[out] outputs - Output tensor shapes. | |||
| /// \return Status return code | |||
| Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | |||
| /// \brief Print args for debugging. | |||
| /// \param[in] out - std::ostream &out. | |||
| void Print(std::ostream &out) const override { out << "SliceWindowOp"; } | |||
| /// \brief Print name of op. | |||
| std::string Name() const override { return kSlidingWindowOp; } | |||
| private: | |||
| uint32_t width_; // The width of the window. Must be an integer and greater than zero. | |||
| int32_t axis_; // The axis along which sliding window is computed, only support 0/-1 for now. | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_ | |||
| @@ -19,13 +19,13 @@ utils provides some general methods for nlp text processing. | |||
| """ | |||
| import platform | |||
| from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \ | |||
| ToNumber | |||
| ToNumber, SlidingWindow | |||
| from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm | |||
| __all__ = [ | |||
| "Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram", | |||
| "to_str", "to_bytes", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber", | |||
| "PythonTokenizer" | |||
| "PythonTokenizer", "SlidingWindow" | |||
| ] | |||
| if platform.system().lower() != 'windows': | |||
| @@ -54,7 +54,7 @@ from .utils import JiebaMode, NormalizeForm, to_str | |||
| from .validators import check_lookup, check_jieba_add_dict, \ | |||
| check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer,\ | |||
| check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate,\ | |||
| check_to_number, check_bert_tokenizer, check_python_tokenizer | |||
| check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow | |||
| from ..core.datatypes import mstype_to_detype | |||
| @@ -72,6 +72,34 @@ class Lookup(cde.LookupOp): | |||
| def __init__(self, vocab, unknown_token=None): | |||
| super().__init__(vocab, unknown_token) | |||
| class SlidingWindow(cde.SlidingWindowOp): | |||
| """ | |||
| TensorOp to construct a tensor from data (only 1-D for now), where each element in the dimension axis | |||
| is a slice of data starting at the corresponding position, with a specified width. | |||
| Args: | |||
| width (int): The width of the window. Must be an integer and greater than zero. | |||
| axis (int, optional): The axis along which sliding window is computed (default=0). | |||
| Examples: | |||
| >>> # Data before | |||
| >>> # | col1 | | |||
| >>> # +-------------+ | |||
| >>> # | [1,2,3,4,5] | | |||
| >>> # +-------------+ | |||
| >>> data = data.map(operations=SlidingWindow(3, 0)) | |||
| >>> # Data after | |||
| >>> # | col1 | | |||
| >>> # +-------------+ | |||
| >>> # | [[1,2,3], | | |||
| >>> # | [2,3,4], | | |||
| >>> # | [3,4,5]] | | |||
| >>> # +--------------+ | |||
| """ | |||
| @check_slidingwindow | |||
| def __init__(self, width, axis=0): | |||
| super().__init__(width=width, axis=axis) | |||
| class Ngram(cde.NgramOp): | |||
| """ | |||
| @@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde | |||
| from mindspore._c_expression import typing | |||
| from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \ | |||
| INT32_MAX, check_value, check_positive | |||
| INT32_MAX, check_value, check_positive, check_pos_int32 | |||
| def check_unique_list_of_words(words, arg_name): | |||
| @@ -328,6 +328,17 @@ def check_from_dataset(method): | |||
| return new_method | |||
| def check_slidingwindow(method): | |||
| """A wrapper that wrap a parameter checker to the original function(sliding window operation).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [width, axis], _ = parse_user_args(method, *args, **kwargs) | |||
| check_pos_int32(width, "width") | |||
| type_check(axis, (int,), "axis") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_ngram(method): | |||
| """A wrapper that wraps a parameter checker to the original function.""" | |||
| @@ -92,6 +92,7 @@ SET(DE_UT_SRCS | |||
| perf_data_test.cc | |||
| c_api_test.cc | |||
| tensor_op_fusion_pass_test.cc | |||
| sliding_window_op_test.cc | |||
| ) | |||
| add_executable(de_ut_tests ${DE_UT_SRCS}) | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "minddata/dataset/text/kernels/sliding_window_op.h" | |||
| #include "utils/log_adapter.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::MsLogLevel::INFO; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::LogStream; | |||
| class MindDataTestSlidingWindowOp : public UT::Common { | |||
| protected: | |||
| MindDataTestSlidingWindowOp() {} | |||
| }; | |||
| TEST_F(MindDataTestSlidingWindowOp, Compute) { | |||
| MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->Compute."; | |||
| std::vector<std::string> strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"}; | |||
| TensorShape shape({static_cast<dsize_t>(strings.size())}); | |||
| std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape); | |||
| std::shared_ptr<Tensor> output; | |||
| std::unique_ptr<SlidingWindowOp> op(new SlidingWindowOp(3, 0)); | |||
| Status s = op->Compute(input, &output); | |||
| std::vector<std::string> out = {"one", "two", "three", "two", "three", "four", "three", "four", "five", | |||
| "four", "five", "six", "five", "six", "seven", "six", "seven", "eight"}; | |||
| std::shared_ptr<Tensor> expected = std::make_shared<Tensor>(out, TensorShape({6, 3})); | |||
| ASSERT_TRUE(output->shape() == expected->shape()); | |||
| ASSERT_TRUE(output->type() == expected->type()); | |||
| MS_LOG(DEBUG) << *output << std::endl; | |||
| MS_LOG(DEBUG) << *expected << std::endl; | |||
| ASSERT_TRUE(*output == *expected); | |||
| MS_LOG(INFO) << "MindDataTestSlidingWindowOp end."; | |||
| } | |||
| TEST_F(MindDataTestSlidingWindowOp, OutputShape) { | |||
| MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->OutputShape."; | |||
| std::vector<std::string> strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"}; | |||
| TensorShape shape({static_cast<dsize_t>(strings.size())}); | |||
| std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape); | |||
| std::vector<TensorShape> input_shape = {input->shape()}; | |||
| std::vector<TensorShape> output_shape = {TensorShape({})}; | |||
| std::unique_ptr<SlidingWindowOp> op(new SlidingWindowOp(3, 0)); | |||
| Status s = op->OutputShape(input_shape, output_shape); | |||
| MS_LOG(DEBUG) << "input_shape" << input_shape[0]; | |||
| MS_LOG(DEBUG) << "output_shape" << output_shape[0]; | |||
| ASSERT_TRUE(output_shape[0] == TensorShape({6, 3})); | |||
| MS_LOG(INFO) << "MindDataTestSlidingWindowOp end."; | |||
| } | |||
| @@ -0,0 +1,105 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing SlidingWindow in mindspore.dataset | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.text as text | |||
| def test_sliding_window_string(): | |||
| """ test sliding_window with string type""" | |||
| inputs = [["大", "家", "早", "上", "好"]] | |||
| expect = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']]) | |||
| dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False) | |||
| dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0)) | |||
| result = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| for i in range(data['text'].shape[0]): | |||
| result.append([]) | |||
| for j in range(data['text'].shape[1]): | |||
| result[i].append(data['text'][i][j].decode('utf8')) | |||
| result = np.array(result) | |||
| np.testing.assert_array_equal(result, expect) | |||
| def test_sliding_window_number(): | |||
| inputs = [1] | |||
| expect = np.array([[1]]) | |||
| def gen(nums): | |||
| yield (np.array(nums),) | |||
| dataset = ds.GeneratorDataset(gen(inputs), column_names=["number"]) | |||
| dataset = dataset.map(input_columns=["number"], operations=text.SlidingWindow(1, -1)) | |||
| for data in dataset.create_dict_iterator(): | |||
| np.testing.assert_array_equal(data['number'], expect) | |||
| def test_sliding_window_big_width(): | |||
| inputs = [[1, 2, 3, 4, 5]] | |||
| expect = np.array([]) | |||
| dataset = ds.NumpySlicesDataset(inputs, column_names=["number"], shuffle=False) | |||
| dataset = dataset.map(input_columns=["number"], operations=text.SlidingWindow(30, 0)) | |||
| for data in dataset.create_dict_iterator(): | |||
| np.testing.assert_array_equal(data['number'], expect) | |||
| def test_sliding_window_exception(): | |||
| try: | |||
| _ = text.SlidingWindow(0, 0) | |||
| assert False | |||
| except ValueError: | |||
| pass | |||
| try: | |||
| _ = text.SlidingWindow("1", 0) | |||
| assert False | |||
| except TypeError: | |||
| pass | |||
| try: | |||
| _ = text.SlidingWindow(1, "0") | |||
| assert False | |||
| except TypeError: | |||
| pass | |||
| try: | |||
| inputs = [[1, 2, 3, 4, 5]] | |||
| dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False) | |||
| dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(3, -100)) | |||
| for _ in dataset.create_dict_iterator(): | |||
| pass | |||
| assert False | |||
| except RuntimeError as e: | |||
| assert "axis supports 0 or -1 only for now." in str(e) | |||
| try: | |||
| inputs = ["aa", "bb", "cc"] | |||
| dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False) | |||
| dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0)) | |||
| for _ in dataset.create_dict_iterator(): | |||
| pass | |||
| assert False | |||
| except RuntimeError as e: | |||
| assert "SlidingWindosOp supports 1D Tensors only for now." in str(e) | |||
| if __name__ == '__main__': | |||
| test_sliding_window_string() | |||
| test_sliding_window_number() | |||
| test_sliding_window_big_width() | |||
| test_sliding_window_exception() | |||