| @@ -77,6 +77,7 @@ | |||||
| #include "minddata/dataset/text/kernels/jieba_tokenizer_op.h" | #include "minddata/dataset/text/kernels/jieba_tokenizer_op.h" | ||||
| #include "minddata/dataset/text/kernels/lookup_op.h" | #include "minddata/dataset/text/kernels/lookup_op.h" | ||||
| #include "minddata/dataset/text/kernels/ngram_op.h" | #include "minddata/dataset/text/kernels/ngram_op.h" | ||||
| #include "minddata/dataset/text/kernels/sliding_window_op.h" | |||||
| #include "minddata/dataset/text/kernels/to_number_op.h" | #include "minddata/dataset/text/kernels/to_number_op.h" | ||||
| #include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h" | #include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h" | ||||
| #include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" | #include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" | ||||
| @@ -640,6 +641,9 @@ void bindTokenizerOps(py::module *m) { | |||||
| py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, | py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, | ||||
| py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), | py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), | ||||
| py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); | py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); | ||||
| (void)py::class_<SlidingWindowOp, TensorOp, std::shared_ptr<SlidingWindowOp>>( | |||||
| *m, "SlidingWindowOp", "TensorOp to apply sliding window to a 1-D Tensor.") | |||||
| .def(py::init<uint32_t, int32_t>(), py::arg("width"), py::arg("axis")); | |||||
| } | } | ||||
| void bindDependIcuTokenizerOps(py::module *m) { | void bindDependIcuTokenizerOps(py::module *m) { | ||||
| @@ -120,6 +120,7 @@ constexpr char kCaseFoldOp[] = "CaseFoldOp"; | |||||
| constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp"; | constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp"; | ||||
| constexpr char kLookupOp[] = "LookupOp"; | constexpr char kLookupOp[] = "LookupOp"; | ||||
| constexpr char kNgramOp[] = "NgramOp"; | constexpr char kNgramOp[] = "NgramOp"; | ||||
| constexpr char kSlidingWindowOp[] = "SlidingWindowOp"; | |||||
| constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op"; | constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op"; | ||||
| constexpr char kRegexReplaceOp[] = "RegexReplaceOp"; | constexpr char kRegexReplaceOp[] = "RegexReplaceOp"; | ||||
| constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp"; | constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp"; | ||||
| @@ -12,10 +12,12 @@ if (NOT (CMAKE_SYSTEM_NAME MATCHES "Windows")) | |||||
| whitespace_tokenizer_op.cc) | whitespace_tokenizer_op.cc) | ||||
| endif() | endif() | ||||
| add_library(text-kernels OBJECT | add_library(text-kernels OBJECT | ||||
| data_utils.cc | |||||
| lookup_op.cc | lookup_op.cc | ||||
| jieba_tokenizer_op.cc | jieba_tokenizer_op.cc | ||||
| unicode_char_tokenizer_op.cc | unicode_char_tokenizer_op.cc | ||||
| ngram_op.cc | ngram_op.cc | ||||
| sliding_window_op.cc | |||||
| wordpiece_tokenizer_op.cc | wordpiece_tokenizer_op.cc | ||||
| truncate_sequence_pair_op.cc | truncate_sequence_pair_op.cc | ||||
| to_number_op.cc | to_number_op.cc | ||||
| @@ -0,0 +1,66 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/text/kernels/data_utils.h" | |||||
| #include <algorithm> | |||||
| #include <limits> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/core/pybind_support.h" | |||||
| #include "minddata/dataset/kernels/data/type_cast_op.h" | |||||
| #include "minddata/dataset/kernels/data/slice_op.h" | |||||
| #include "minddata/dataset/kernels/data/concatenate_op.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, TensorShape out_shape, | |||||
| uint32_t width, int32_t axis) { | |||||
| // if the data row has fewer items than width, the corresponding result row will be empty | |||||
| if (out_shape.Size() == 0) { | |||||
| MS_LOG(WARNING) << "The data row has fewer items than width, the result will be empty."; | |||||
| if (input->type().value() == DataType::DE_STRING) { | |||||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(output, std::vector<std::string>{}, TensorShape({0}))); | |||||
| } else { | |||||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, TensorShape({0}), input->type())); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| axis = Tensor::HandleNeg(axis, input->shape().Size()); | |||||
| int32_t axis_end = input->shape()[axis]; | |||||
| std::shared_ptr<Tensor> tmp; | |||||
| auto concatenate_op = std::make_unique<ConcatenateOp>(axis, nullptr, nullptr); | |||||
| // Slice on specified axis and concatenate on new axis | |||||
| for (int32_t i = 0; i + width <= axis_end; i++) { | |||||
| auto slice_op = std::make_unique<SliceOp>(Slice(i, i + width, 1)); | |||||
| slice_op->Compute(input, &tmp); | |||||
| if (i == 0) { | |||||
| *output = tmp; | |||||
| } else { | |||||
| TensorRow in({*output, tmp}); | |||||
| TensorRow out_row; | |||||
| concatenate_op->Compute(in, &out_row); | |||||
| *output = out_row[0]; | |||||
| } | |||||
| } | |||||
| (*output)->Reshape(out_shape); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DATASET_KERNELS_TEXT_DATA_UTILS_H_ | |||||
| #define DATASET_KERNELS_TEXT_DATA_UTILS_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/dataset/core/constants.h" | |||||
| #include "minddata/dataset/core/data_type.h" | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/core/cv_tensor.h" | |||||
| #include "minddata/dataset/core/tensor_shape.h" | |||||
| #include "minddata/dataset/core/tensor_row.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// \brief Helper method that performs sliding window on input tensor. | |||||
| /// \param[in] input - Input tensor. | |||||
| /// \param[in] out_shape - Output shape of output tensor. | |||||
| /// \param[in] width - The width of the window. | |||||
| /// \param[in] axis - The axis along which sliding window is computed. | |||||
| /// \param[out] output - Output tensor | |||||
| /// \return Status return code | |||||
| Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, TensorShape out_shape, | |||||
| uint32_t width, int32_t axis); | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // DATASET_KERNELS_TEXT_DATA_UTILS_H_ | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/text/kernels/sliding_window_op.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| Status SlidingWindowOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||||
| IO_CHECK(input, output); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SlidingWindosOp supports 1D Tensors only for now."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(axis_ == 0 || axis_ == -1, "axis supports 0 or -1 only for now."); | |||||
| std::vector<TensorShape> input_shape = {input->shape()}; | |||||
| std::vector<TensorShape> output_shape = {TensorShape({})}; | |||||
| RETURN_IF_NOT_OK(OutputShape(input_shape, output_shape)); | |||||
| RETURN_IF_NOT_OK(SlidingWindowHelper(input, output, output_shape[0], width_, axis_)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status SlidingWindowOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput(), "incorrect num of inputs\n"); | |||||
| int32_t axis = Tensor::HandleNeg(axis_, inputs[0].Size()); | |||||
| TensorShape input_shape = inputs[0]; | |||||
| std::vector<dsize_t> output_shape_initializer; | |||||
| // if a data row has fewer items than width, the corresponding result row will be empty. | |||||
| if (input_shape[axis] >= width_) { | |||||
| for (int32_t idx = 0; idx < input_shape.Size(); ++idx) { | |||||
| if (idx != axis) { | |||||
| output_shape_initializer.push_back(input_shape[idx]); | |||||
| } else { | |||||
| output_shape_initializer.push_back(input_shape[idx] - (width_ - 1)); | |||||
| output_shape_initializer.push_back(width_); | |||||
| } | |||||
| } | |||||
| } | |||||
| outputs.pop_back(); | |||||
| outputs.emplace_back(TensorShape(output_shape_initializer)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(outputs.size() == NumOutput(), "incorrect num of outputs\n"); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,68 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_ | |||||
| #define DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_ | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/kernels/tensor_op.h" | |||||
| #include "minddata/dataset/text/kernels/data_utils.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class SlidingWindowOp : public TensorOp { | |||||
| public: | |||||
| /// \brief Constructor of SlidingWindowOp. | |||||
| /// \param[in] width - The axis along which sliding window is computed. | |||||
| /// \param[in] axis - The width of the window. | |||||
| /// \return Status return code | |||||
| explicit SlidingWindowOp(uint32_t width, int32_t axis = 0) : width_(width), axis_(axis) {} | |||||
| /// \brief Destructor of SlidingWindowOp. | |||||
| ~SlidingWindowOp() override = default; | |||||
| /// \brief Perform sliding window to tensor. | |||||
| /// \param[in] input - Input tensor of Op. | |||||
| /// \param[out] output - output tensor of Op. | |||||
| /// \return Status return code | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||||
| /// \brief Calculate tensor shape for output tensor. | |||||
| /// \param[in] inputs - Input tensor shapes. | |||||
| /// \param[out] outputs - Output tensor shapes. | |||||
| /// \return Status return code | |||||
| Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | |||||
| /// \brief Print args for debugging. | |||||
| /// \param[in] out - std::ostream &out. | |||||
| void Print(std::ostream &out) const override { out << "SliceWindowOp"; } | |||||
| /// \brief Print name of op. | |||||
| std::string Name() const override { return kSlidingWindowOp; } | |||||
| private: | |||||
| uint32_t width_; // The width of the window. Must be an integer and greater than zero. | |||||
| int32_t axis_; // The axis along which sliding window is computed, only support 0/-1 for now. | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_ | |||||
| @@ -19,13 +19,13 @@ utils provides some general methods for nlp text processing. | |||||
| """ | """ | ||||
| import platform | import platform | ||||
| from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \ | from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \ | ||||
| ToNumber | |||||
| ToNumber, SlidingWindow | |||||
| from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm | from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm | ||||
| __all__ = [ | __all__ = [ | ||||
| "Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram", | "Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram", | ||||
| "to_str", "to_bytes", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber", | "to_str", "to_bytes", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber", | ||||
| "PythonTokenizer" | |||||
| "PythonTokenizer", "SlidingWindow" | |||||
| ] | ] | ||||
| if platform.system().lower() != 'windows': | if platform.system().lower() != 'windows': | ||||
| @@ -54,7 +54,7 @@ from .utils import JiebaMode, NormalizeForm, to_str | |||||
| from .validators import check_lookup, check_jieba_add_dict, \ | from .validators import check_lookup, check_jieba_add_dict, \ | ||||
| check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer,\ | check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer,\ | ||||
| check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate,\ | check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate,\ | ||||
| check_to_number, check_bert_tokenizer, check_python_tokenizer | |||||
| check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow | |||||
| from ..core.datatypes import mstype_to_detype | from ..core.datatypes import mstype_to_detype | ||||
| @@ -72,6 +72,34 @@ class Lookup(cde.LookupOp): | |||||
| def __init__(self, vocab, unknown_token=None): | def __init__(self, vocab, unknown_token=None): | ||||
| super().__init__(vocab, unknown_token) | super().__init__(vocab, unknown_token) | ||||
class SlidingWindow(cde.SlidingWindowOp):
    """
    TensorOp that turns data (1-D only for now) into a tensor of windows: every entry along the
    dimension axis is a slice of the data that starts at that position and spans `width` items.

    Args:
        width (int): The width of the window. Must be an integer and greater than zero.
        axis (int, optional): The axis along which sliding window is computed (default=0).

    Examples:
        >>> # Data before
        >>> # |    col1     |
        >>> # +-------------+
        >>> # | [1,2,3,4,5] |
        >>> # +-------------+
        >>> data = data.map(operations=SlidingWindow(3, 0))
        >>> # Data after
        >>> # |    col1     |
        >>> # +-------------+
        >>> # |  [[1,2,3],  |
        >>> # |   [2,3,4],  |
        >>> # |   [3,4,5]]  |
        >>> # +-------------+
    """

    @check_slidingwindow
    def __init__(self, width, axis=0):
        # Arguments are validated by the decorator before they reach the C++ op.
        super().__init__(width, axis)
| class Ngram(cde.NgramOp): | class Ngram(cde.NgramOp): | ||||
| """ | """ | ||||
| @@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde | |||||
| from mindspore._c_expression import typing | from mindspore._c_expression import typing | ||||
| from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \ | from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \ | ||||
| INT32_MAX, check_value, check_positive | |||||
| INT32_MAX, check_value, check_positive, check_pos_int32 | |||||
| def check_unique_list_of_words(words, arg_name): | def check_unique_list_of_words(words, arg_name): | ||||
| @@ -328,6 +328,17 @@ def check_from_dataset(method): | |||||
| return new_method | return new_method | ||||
def check_slidingwindow(method):
    """Wrap `method` with a parameter check for the sliding window operation (width and axis)."""

    @wraps(method)
    def checked(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        width, axis = params
        check_pos_int32(width, "width")
        type_check(axis, (int,), "axis")
        return method(self, *args, **kwargs)

    return checked
| def check_ngram(method): | def check_ngram(method): | ||||
| """A wrapper that wraps a parameter checker to the original function.""" | """A wrapper that wraps a parameter checker to the original function.""" | ||||
| @@ -92,6 +92,7 @@ SET(DE_UT_SRCS | |||||
| perf_data_test.cc | perf_data_test.cc | ||||
| c_api_test.cc | c_api_test.cc | ||||
| tensor_op_fusion_pass_test.cc | tensor_op_fusion_pass_test.cc | ||||
| sliding_window_op_test.cc | |||||
| ) | ) | ||||
| add_executable(de_ut_tests ${DE_UT_SRCS}) | add_executable(de_ut_tests ${DE_UT_SRCS}) | ||||
| @@ -0,0 +1,69 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/common.h" | |||||
| #include "minddata/dataset/text/kernels/sliding_window_op.h" | |||||
| #include "utils/log_adapter.h" | |||||
| using namespace mindspore::dataset; | |||||
| using mindspore::MsLogLevel::INFO; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::LogStream; | |||||
| class MindDataTestSlidingWindowOp : public UT::Common { | |||||
| protected: | |||||
| MindDataTestSlidingWindowOp() {} | |||||
| }; | |||||
| TEST_F(MindDataTestSlidingWindowOp, Compute) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->Compute."; | |||||
| std::vector<std::string> strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"}; | |||||
| TensorShape shape({static_cast<dsize_t>(strings.size())}); | |||||
| std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape); | |||||
| std::shared_ptr<Tensor> output; | |||||
| std::unique_ptr<SlidingWindowOp> op(new SlidingWindowOp(3, 0)); | |||||
| Status s = op->Compute(input, &output); | |||||
| std::vector<std::string> out = {"one", "two", "three", "two", "three", "four", "three", "four", "five", | |||||
| "four", "five", "six", "five", "six", "seven", "six", "seven", "eight"}; | |||||
| std::shared_ptr<Tensor> expected = std::make_shared<Tensor>(out, TensorShape({6, 3})); | |||||
| ASSERT_TRUE(output->shape() == expected->shape()); | |||||
| ASSERT_TRUE(output->type() == expected->type()); | |||||
| MS_LOG(DEBUG) << *output << std::endl; | |||||
| MS_LOG(DEBUG) << *expected << std::endl; | |||||
| ASSERT_TRUE(*output == *expected); | |||||
| MS_LOG(INFO) << "MindDataTestSlidingWindowOp end."; | |||||
| } | |||||
| TEST_F(MindDataTestSlidingWindowOp, OutputShape) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->OutputShape."; | |||||
| std::vector<std::string> strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"}; | |||||
| TensorShape shape({static_cast<dsize_t>(strings.size())}); | |||||
| std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape); | |||||
| std::vector<TensorShape> input_shape = {input->shape()}; | |||||
| std::vector<TensorShape> output_shape = {TensorShape({})}; | |||||
| std::unique_ptr<SlidingWindowOp> op(new SlidingWindowOp(3, 0)); | |||||
| Status s = op->OutputShape(input_shape, output_shape); | |||||
| MS_LOG(DEBUG) << "input_shape" << input_shape[0]; | |||||
| MS_LOG(DEBUG) << "output_shape" << output_shape[0]; | |||||
| ASSERT_TRUE(output_shape[0] == TensorShape({6, 3})); | |||||
| MS_LOG(INFO) << "MindDataTestSlidingWindowOp end."; | |||||
| } | |||||
| @@ -0,0 +1,105 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """ | |||||
| Testing SlidingWindow in mindspore.dataset | |||||
| """ | |||||
| import numpy as np | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.text as text | |||||
def test_sliding_window_string():
    """Sliding window of width 2 over a single row of string tokens."""
    inputs = [["大", "家", "早", "上", "好"]]
    expect = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']])

    dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
    dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0))

    decoded = []
    for data in dataset.create_dict_iterator():
        column = data['text']
        n_rows, n_cols = column.shape[0], column.shape[1]
        # The column holds bytes; decode each cell so it can be compared with the str expectation.
        decoded = [[column[i][j].decode('utf8') for j in range(n_cols)] for i in range(n_rows)]
    np.testing.assert_array_equal(np.array(decoded), expect)
def test_sliding_window_number():
    """Sliding window of width 1 along axis -1 over a one-element numeric row."""
    inputs = [1]
    expect = np.array([[1]])

    def gen(nums):
        yield (np.array(nums),)

    source = gen(inputs)
    window_op = text.SlidingWindow(1, -1)
    dataset = ds.GeneratorDataset(source, column_names=["number"])
    dataset = dataset.map(input_columns=["number"], operations=window_op)

    for row in dataset.create_dict_iterator():
        np.testing.assert_array_equal(row['number'], expect)
def test_sliding_window_big_width():
    """A window wider than the row is expected to produce an empty result."""
    inputs = [[1, 2, 3, 4, 5]]
    expect = np.array([])

    window_op = text.SlidingWindow(30, 0)
    dataset = ds.NumpySlicesDataset(inputs, column_names=["number"], shuffle=False)
    dataset = dataset.map(input_columns=["number"], operations=window_op)

    for row in dataset.create_dict_iterator():
        np.testing.assert_array_equal(row['number'], expect)
def test_sliding_window_exception():
    """Invalid arguments are rejected at construction; unsupported axis/rank fail when the pipeline runs."""
    # width of zero
    try:
        _ = text.SlidingWindow(0, 0)
    except ValueError:
        pass
    else:
        assert False

    # width of the wrong type
    try:
        _ = text.SlidingWindow("1", 0)
    except TypeError:
        pass
    else:
        assert False

    # axis of the wrong type
    try:
        _ = text.SlidingWindow(1, "0")
    except TypeError:
        pass
    else:
        assert False

    # axis other than 0 / -1 is only detected while iterating
    try:
        inputs = [[1, 2, 3, 4, 5]]
        dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
        dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(3, -100))
        for _ in dataset.create_dict_iterator():
            pass
    except RuntimeError as e:
        assert "axis supports 0 or -1 only for now." in str(e)
    else:
        assert False

    # rows that are not 1-D are only detected while iterating
    try:
        inputs = ["aa", "bb", "cc"]
        dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
        dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0))
        for _ in dataset.create_dict_iterator():
            pass
    except RuntimeError as e:
        assert "SlidingWindosOp supports 1D Tensors only for now." in str(e)
    else:
        assert False
# Run every test in this file, in order, when executed as a script.
if __name__ == '__main__':
    for test_case in (test_sliding_window_string,
                      test_sliding_window_number,
                      test_sliding_window_big_width,
                      test_sliding_window_exception):
        test_case()