diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 7995c81491..4dd75eb776 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -17,6 +17,7 @@ #include "dataset/api/de_pipeline.h" #include "dataset/kernels/no_op.h" +#include "dataset/kernels/data/concatenate_op.h" #include "dataset/kernels/data/one_hot_op.h" #include "dataset/kernels/image/center_crop_op.h" #include "dataset/kernels/image/cut_out_op.h" @@ -434,6 +435,11 @@ void bindTensorOps2(py::module *m) { *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length") .def(py::init()); + (void)py::class_>(*m, "ConcatenateOp", + "Tensor operation concatenate tensors.") + .def(py::init, std::shared_ptr>(), py::arg("axis"), + py::arg("prepend").none(true), py::arg("append").none(true)); + (void)py::class_>( *m, "RandomRotationOp", "Tensor operation to apply RandomRotation." diff --git a/mindspore/ccsrc/dataset/core/tensor.cc b/mindspore/ccsrc/dataset/core/tensor.cc index 074603f833..a3c3e4533c 100644 --- a/mindspore/ccsrc/dataset/core/tensor.cc +++ b/mindspore/ccsrc/dataset/core/tensor.cc @@ -589,11 +589,13 @@ Status Tensor::StartAddrOfIndex(std::vector ind, uchar **start_addr_of_ if (type() == DataType::DE_STRING) { RETURN_STATUS_UNEXPECTED("StartAddrOfIndex does not support string tensors yet."); } + dsize_t flat_ind; std::vector t_shape = shape().AsVector(); std::vector r(t_shape.begin() + ind.size(), t_shape.end()); *remaining = TensorShape(r); ind.resize(this->Rank(), 0); // same as -> while (ind.size() < this->Rank()) ind.push_back(0); + RETURN_IF_NOT_OK(shape_.ToFlatIndex(ind, &flat_ind)); // check if GetBuffer() returns null, we should flag this as an error, this sanity check will only // be true is the tensor failed to allocate memory. 
@@ -634,6 +636,48 @@ Status Tensor::InsertTensor(const std::vector &ind, const std::shared_p } }
+// Copy the elements of `tensor` into this tensor's buffer at the position addressed by the single-element `index`.
+// Unlike InsertTensor, `tensor` may be smaller than the space that is left behind that position.
+// Returns an error Status without copying when an argument is invalid or the data does not fit.
+Status Tensor::Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &tensor) {
+  // Checked first: every later test dereferences `tensor` or reads index.at(0), which throws on an empty index.
+  if (tensor == nullptr) {
+    RETURN_STATUS_UNEXPECTED("[Tensor] tensor to concatenate is null\n");
+  }
+  if (index.size() != 1) {
+    RETURN_STATUS_UNEXPECTED("[Tensor] only supports 1d concatenation \n");
+  }
+
+  std::string err_msg;
+  err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : "";
+  err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : "";
+  err_msg +=
+    (index.at(0) + tensor->shape().NumOfElements() > this->shape().NumOfElements()) ? "[Tensor] incorrect index\n" : "";
+  err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : "";
+  if (!err_msg.empty()) {
+    MS_LOG(DEBUG) << "Insert tensor message: " << err_msg;
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // Propagate the actual failure reported by StartAddrOfIndex instead of discarding its Status.
+  uchar *start_addr_of_ind = nullptr;
+  TensorShape remaining_shape = tensor->shape();
+  RETURN_IF_NOT_OK(StartAddrOfIndex(index, &start_addr_of_ind, &remaining_shape));
+  if (start_addr_of_ind == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Failed to create memory for Tensor.\n");
+  }
+
+  // destMax equals the source size; the index check above is what bounds the write for 1-D tensors.
+  int ret_code =
+    memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes());
+  if (ret_code != 0) {
+    err_msg += "[Tensor] error in memcpy_s when inserting tensor\n";
+    MS_LOG(DEBUG) << "Tensor message: " << err_msg;
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  return Status::OK();
+}
+
 Status Tensor::ExpandDim(const dsize_t &axis) { if (axis > Rank()) { std::string err = "Axis is out of bound";
diff --git a/mindspore/ccsrc/dataset/core/tensor.h b/mindspore/ccsrc/dataset/core/tensor.h
index ad503a9290..0aec84f77b 100644
--- a/mindspore/ccsrc/dataset/core/tensor.h
+++ b/mindspore/ccsrc/dataset/core/tensor.h
@@ -372,6 +372,9 @@ class Tensor { static Status GetBufferInfo(Tensor &t, py::buffer_info *out);
+  // Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor
+
Status Concatenate(const std::vector &index, const std::shared_ptr &input); + // TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor // The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 // @tparam T type of values in the Tensor Iterator diff --git a/mindspore/ccsrc/dataset/core/tensor_shape.h b/mindspore/ccsrc/dataset/core/tensor_shape.h index 27fce91aec..c83e43cd7d 100644 --- a/mindspore/ccsrc/dataset/core/tensor_shape.h +++ b/mindspore/ccsrc/dataset/core/tensor_shape.h @@ -94,7 +94,7 @@ class TensorShape { // @return TensorShape PrependDim(dsize_t dim) const; - // Insert a new dim at the end of the shape. For example, <2,4> --> PrependDim(4) --> <2,4,4> + // Insert a new dim at the end of the shape. For example, <2,4> --> AppendDim(4) --> <2,4,4> // @param dim // @return TensorShape AppendDim(dsize_t dim) const; diff --git a/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt b/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt index 05ca7b360a..1df952f351 100644 --- a/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt @@ -1,12 +1,13 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_library(kernels-data OBJECT - data_utils.cc - one_hot_op.cc - pad_end_op.cc - type_cast_op.cc - to_float16_op.cc - fill_op.cc - slice_op.cc - mask_op.cc - ) + data_utils.cc + one_hot_op.cc + pad_end_op.cc + type_cast_op.cc + to_float16_op.cc + fill_op.cc + slice_op.cc + mask_op.cc + concatenate_op.cc + ) diff --git a/mindspore/ccsrc/dataset/kernels/data/concatenate_op.cc b/mindspore/ccsrc/dataset/kernels/data/concatenate_op.cc new file mode 100644 index 0000000000..87115fd3ce --- /dev/null +++ b/mindspore/ccsrc/dataset/kernels/data/concatenate_op.cc @@ -0,0 +1,55 @@ +/** + * Copyright 
2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "dataset/kernels/data/concatenate_op.h"
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/data/data_utils.h"
+#include "dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status ConcatenateOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+  RETURN_IF_NOT_OK(Concatenate(input, output, axis_, prepend_, append_));
+  return Status::OK();
+}
+
+// Output is a single 1-D shape: element count of prepend + every input shape + append.
+Status ConcatenateOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
+  RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
+  CHECK_FAIL_RETURN_UNEXPECTED(!inputs.empty(), "No input tensor shape provided");
+  // Compute concatenates every tensor of the row, so every input shape counts, not only the first one.
+  dsize_t output_shape = 0;
+  for (const TensorShape &input_shape : inputs) {
+    CHECK_FAIL_RETURN_UNEXPECTED(input_shape.Rank() == 1, "Only 1D input tensors supported");
+    output_shape = output_shape + input_shape.NumOfElements();
+  }
+  outputs.clear();
+  if (prepend_ != nullptr) {
+    CHECK_FAIL_RETURN_UNEXPECTED(prepend_->shape().Rank() == 1, "Only 1D prepend tensors supported");
+    output_shape = output_shape + prepend_->shape().NumOfElements();
+  }
+  if (append_ != nullptr) {
+    CHECK_FAIL_RETURN_UNEXPECTED(append_->shape().Rank() == 1, "Only 1D append tensors supported");
+    output_shape = output_shape + append_->shape().NumOfElements();
+  }
+
+  outputs.emplace_back(std::vector<dsize_t>{output_shape});
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git
a/mindspore/ccsrc/dataset/kernels/data/concatenate_op.h b/mindspore/ccsrc/dataset/kernels/data/concatenate_op.h
new file mode 100644
index 0000000000..4e4c7ad4e0
--- /dev/null
+++ b/mindspore/ccsrc/dataset/kernels/data/concatenate_op.h
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_KERNELS_DATA_CONCATENATE_OP_H_
+#define DATASET_KERNELS_DATA_CONCATENATE_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+/// Tensor operation that joins the 1-D tensors of a row into a single 1-D tensor, optionally placing a fixed
+/// `prepend` tensor in front of and a fixed `append` tensor behind the row's data.
+class ConcatenateOp : public TensorOp {
+ public:
+  /// Constructor to ConcatenateOp.
+  /// @param int8_t axis - axis to concatenate tensors along.
+  /// @param std::shared_ptr<Tensor> prepend - prepend tensor, may be nullptr.
+  /// @param std::shared_ptr<Tensor> append - append tensor, may be nullptr.
+  explicit ConcatenateOp(int8_t axis, std::shared_ptr<Tensor> prepend, std::shared_ptr<Tensor> append)
+      : axis_(axis), prepend_(prepend), append_(append) {}
+
+  ~ConcatenateOp() override = default;
+
+  /// Print method to see which tensor Op this is.
+  /// @param std::ostream &out - output stream object.
+  void Print(std::ostream &out) const override { out << "ConcatenateOp"; }
+
+  /// Compute method allowing multiple tensors as inputs
+  /// @param TensorRow &input - input tensor rows
+  /// @param TensorRow *output - output tensor rows
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+
+  /// Compute tensor output shape
+  /// @param std::vector<TensorShape> &inputs - vector of input tensor shapes
+  /// @param std::vector<TensorShape> &outputs - vector of output tensor shapes
+  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
+
+  /// Number of inputs the tensor operation accepts
+  uint32_t NumInput() override { return 0; }
+
+ private:
+  int8_t axis_;
+  std::shared_ptr<Tensor> prepend_;
+  std::shared_ptr<Tensor> append_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_DATA_CONCATENATE_OP_H_
diff --git a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/dataset/kernels/data/data_utils.cc
index cbbfa08e99..5a20926618 100644
--- a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc
+++ b/mindspore/ccsrc/dataset/kernels/data/data_utils.cc
@@ -555,5 +555,92 @@ Status Mask(const std::shared_ptr &input, std::shared_ptr *outpu } return Status::OK(); }
+
+// Concatenates, along axis 0, the optional `prepend` tensor, every tensor of `input` in order and the optional
+// `append` tensor, and pushes the single resulting tensor onto `output`.
+// Only rank-1 tensors of one common type are supported; axis may be given as 0 or -1.
+Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
+                   std::shared_ptr<Tensor> append) {
+  CHECK_FAIL_RETURN_UNEXPECTED(output != nullptr, "Output row is null");
+  CHECK_FAIL_RETURN_UNEXPECTED(input.size() > 0, "No input tensor to concatenate");
+  CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported");
+  CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported");
+
+  // Keep the normalised index returned by HandleNeg: the call cannot modify `axis` in place,
+  // so without the assignment axis == -1 is rejected by the check below.
+  axis = static_cast<int8_t>(Tensor::HandleNeg(axis, input[0]->shape().Rank()));
+  CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported");
+
+  std::shared_ptr<Tensor> out;
+  if (prepend != nullptr) {
+    CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported");
+    RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0]));
+  } else {
+    out = input[0];
+  }
+  for (dsize_t i = 1; i < input.size(); i++) {
+    std::shared_ptr<Tensor> out_t;
+    CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported");
+    RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i]));
+    out = out_t;
+  }
+  std::shared_ptr<Tensor> out_t;
+  if (append != nullptr) {
+    CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported");
+    RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append));
+  } else {
+    out_t = out;
+  }
+  output->push_back(out_t);
+
+  return Status::OK();
+}
+
+// Helper for Concatenate: builds a new tensor holding the elements of `input` followed by those of `append` and
+// stores it in `output`. Both tensors must have the same type; numeric data is copied, strings are re-collected.
+Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
+                         std::shared_ptr<Tensor> append) {
+  CHECK_FAIL_RETURN_UNEXPECTED(input != nullptr && append != nullptr, "Tensors to concatenate must not be null");
+  CHECK_FAIL_RETURN_UNEXPECTED(output != nullptr, "Output tensor pointer is null");
+  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match");
+  CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == append->shape().Rank(), "Tensor ranks do not match");
+
+  TensorShape t({});
+
+  for (dsize_t i = 0; i < input->shape().Rank(); i++) {
+    if (i != axis) {
+      t = t.AppendDim(input->shape()[i]);
+    } else {
+      dsize_t new_shape = input->shape()[i] + append->shape()[i];
+
+      t = t.AppendDim(new_shape);
+    }
+  }
+  std::shared_ptr<Tensor> out;
+
+  if (input->type().IsNumeric()) {
+    RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, t, input->type()));
+
+    RETURN_IF_NOT_OK(out->Concatenate({0}, input));
+    RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append));
+    *output = out;
+  } else {
+    std::vector<std::string> strings;
+
+    auto itr = input->begin<std::string_view>();
+    for (; itr != input->end<std::string_view>(); itr++) {
+      strings.emplace_back(*itr);
+    }
+    itr = append->begin<std::string_view>();
+    for (; itr != append->end<std::string_view>(); itr++) {
+      strings.emplace_back(*itr);
+    }
+    RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, t));
+
+    *output = out;
+  }
+
+  return Status::OK();
+}
 } // namespace dataset } // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/data/data_utils.h b/mindspore/ccsrc/dataset/kernels/data/data_utils.h
index 4dec0f0470..6034e2a0eb 100644
--- a/mindspore/ccsrc/dataset/kernels/data/data_utils.h
+++ b/mindspore/ccsrc/dataset/kernels/data/data_utils.h
@@ -23,6 +23,7 @@ #include "dataset/core/cv_tensor.h" #include "dataset/core/data_type.h" #include "dataset/core/tensor.h"
+#include "dataset/core/tensor_row.h"
 namespace mindspore { namespace
dataset { @@ -148,6 +149,16 @@ Status MaskHelper(const std::shared_ptr &input, const std::shared_ptr &input, std::shared_ptr *output, const std::shared_ptr &value, RelationalOp op);
+
+// Concatenate the optional prepend tensor, every tensor of the input row and the optional append tensor into one
+// tensor that is pushed onto the output row. prepend and append may be nullptr.
+Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
+                   std::shared_ptr<Tensor> append);
+
+// helper for concat, always append to the input, and pass that to the output
+Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
+                         std::shared_ptr<Tensor> append);
+
 } // namespace dataset } // namespace mindspore
diff --git a/mindspore/dataset/transforms/c_transforms.py b/mindspore/dataset/transforms/c_transforms.py
index 3167c642f4..903315ef0b 100644
--- a/mindspore/dataset/transforms/c_transforms.py
+++ b/mindspore/dataset/transforms/c_transforms.py
@@ -16,13 +16,13 @@ This module c_transforms provides common operations, including OneHotOp and TypeCast. """ from enum import IntEnum +import numpy as np import mindspore.common.dtype as mstype import mindspore._c_dataengine as cde -import numpy as np - -from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, check_pad_end +from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, \ + check_pad_end, check_concat_type from ..core.datatypes import mstype_to_detype
@@ -187,3 +187,20 @@ class PadEnd(cde.PadEndOp): if pad_value is not None: pad_value = cde.Tensor(np.array(pad_value)) super().__init__(cde.TensorShape(pad_shape), pad_value)
+
+
+class Concatenate(cde.ConcatenateOp):
+    """
+    Tensor operation that concatenates the 1D input tensors of a row into one 1D tensor, optionally placing a
+    fixed tensor in front of and/or behind them.
+
+    Args:
+        axis (int, optional): axis to concatenate the tensors along, only 0 and -1 are accepted (Default=0).
+        prepend (np.array, optional): numpy array to be prepended to the already concatenated tensors (Default=None).
+        append (np.array, optional): numpy array to be appended to the already concatenated tensors (Default=None).
+    """
+
+    @check_concat_type
+    def __init__(self, axis=0, prepend=None, append=None):
+        # check_concat_type has already validated axis and converted prepend/append to cde.Tensor (or None).
+        super().__init__(axis, prepend, append)
diff --git a/mindspore/dataset/transforms/validators.py b/mindspore/dataset/transforms/validators.py
index 8fa701ace6..ba43228418 100644
--- a/mindspore/dataset/transforms/validators.py
+++ b/mindspore/dataset/transforms/validators.py
@@ -15,7 +15,9 @@ """Validators for TensorOps. """ from functools import wraps +import numpy as np +import mindspore._c_dataengine as cde from mindspore._c_expression import typing # POS_INT_MIN is used to limit values from starting from 0
@@ -230,10 +232,11 @@ def check_mask_op(method): if operator is None: raise ValueError("operator is not provided.") + + from .c_transforms import Relational if constant is None: raise ValueError("constant is not provided.") - from .c_transforms import Relational if not isinstance(operator, Relational): raise TypeError("operator is not a Relational operator enum.")
@@ -282,3 +285,48 @@ def check_pad_end(method): return method(self, **kwargs) return new_method
+
+
+def check_concat_type(method):
+    """Wrapper method to check the parameters of concatenation op."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        # Positional arguments fill (axis, prepend, append) in order; keyword arguments take precedence.
+        axis, prepend, append = (list(args) + 3 * [None])[:3]
+        if "prepend" in kwargs:
+            prepend = kwargs.get("prepend")
+        if "append" in kwargs:
+            append = kwargs.get("append")
+        if "axis" in kwargs:
+            axis = kwargs.get("axis")
+
+        if not isinstance(axis, (type(None), int)):
+            raise TypeError("axis type is not valid, must be None or an integer.")
+
+        if isinstance(axis, type(None)):
+            axis = 0
+
+        if axis not in (0, -1):
+            raise ValueError("only 1D concatenation supported.")
+
+        if not isinstance(prepend, (type(None), np.ndarray)):
+            raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.")
+
+        if not isinstance(append, (type(None), np.ndarray)):
+            raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.")
+
+        # The C++ op takes cde.Tensor objects, so numpy arrays are wrapped here; None is passed through.
+        if isinstance(prepend, np.ndarray):
+            prepend = cde.Tensor(prepend)
+
+        if isinstance(append, np.ndarray):
+            append = cde.Tensor(append)
+
+        kwargs["axis"] = axis
+        kwargs["prepend"] = prepend
+        kwargs["append"] = append
+
+        return method(self, **kwargs)
+
+    return new_method
diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt
index 68ce6186d1..317f9d67c3 100644
--- a/tests/ut/cpp/dataset/CMakeLists.txt
+++ b/tests/ut/cpp/dataset/CMakeLists.txt
@@ -1,83 +1,84 @@ include(GoogleTest) SET(DE_UT_SRCS - common/common.cc - common/cvop_common.cc - batch_op_test.cc - bit_functions_test.cc - storage_container_test.cc - treap_test.cc - interrupt_test.cc - image_folder_op_test.cc - buddy_test.cc - arena_test.cc - btree_test.cc - center_crop_op_test.cc - channel_swap_test.cc - circular_pool_test.cc - client_config_test.cc - connector_test.cc - datatype_test.cc - decode_op_test.cc - execution_tree_test.cc - global_context_test.cc - main_test.cc - map_op_test.cc - mind_record_op_test.cc - memory_pool_test.cc - normalize_op_test.cc - one_hot_op_test.cc - pad_end_op_test.cc - path_test.cc - project_op_test.cc - queue_test.cc - random_crop_op_test.cc - random_crop_decode_resize_op_test.cc - random_crop_and_resize_op_test.cc - random_color_adjust_op_test.cc - random_horizontal_flip_op_test.cc - random_resize_op_test.cc - random_rotation_op_test.cc - random_vertical_flip_op_test.cc - rename_op_test.cc - repeat_op_test.cc - skip_op_test.cc - rescale_op_test.cc - resize_bilinear_op_test.cc - resize_op_test.cc - shuffle_op_test.cc - stand_alone_samplers_test.cc - status_test.cc - storage_op_test.cc - task_manager_test.cc - tensor_test.cc - tensor_string_test.cc - tensorshape_test.cc - tfReader_op_test.cc - to_float16_op_test.cc - type_cast_op_test.cc - zip_op_test.cc - random_resize_op_test.cc - subset_random_sampler_test.cc - weighted_random_sampler_test.cc - mnist_op_test.cc -
manifest_op_test.cc - voc_op_test.cc - cifar_op_test.cc - celeba_op_test.cc - take_op_test.cc - clue_op_test.cc - text_file_op_test.cc - filter_op_test.cc - concat_op_test.cc - jieba_tokenizer_op_test.cc - tokenizer_op_test.cc - gnn_graph_test.cc - coco_op_test.cc - fill_op_test.cc - mask_test.cc - trucate_pair_test.cc - ) + common/common.cc + common/cvop_common.cc + batch_op_test.cc + bit_functions_test.cc + storage_container_test.cc + treap_test.cc + interrupt_test.cc + image_folder_op_test.cc + buddy_test.cc + arena_test.cc + btree_test.cc + center_crop_op_test.cc + channel_swap_test.cc + circular_pool_test.cc + client_config_test.cc + connector_test.cc + datatype_test.cc + decode_op_test.cc + execution_tree_test.cc + global_context_test.cc + main_test.cc + map_op_test.cc + mind_record_op_test.cc + memory_pool_test.cc + normalize_op_test.cc + one_hot_op_test.cc + pad_end_op_test.cc + path_test.cc + project_op_test.cc + queue_test.cc + random_crop_op_test.cc + random_crop_decode_resize_op_test.cc + random_crop_and_resize_op_test.cc + random_color_adjust_op_test.cc + random_horizontal_flip_op_test.cc + random_resize_op_test.cc + random_rotation_op_test.cc + random_vertical_flip_op_test.cc + rename_op_test.cc + repeat_op_test.cc + skip_op_test.cc + rescale_op_test.cc + resize_bilinear_op_test.cc + resize_op_test.cc + shuffle_op_test.cc + stand_alone_samplers_test.cc + status_test.cc + storage_op_test.cc + task_manager_test.cc + tensor_test.cc + tensor_string_test.cc + tensorshape_test.cc + tfReader_op_test.cc + to_float16_op_test.cc + type_cast_op_test.cc + zip_op_test.cc + random_resize_op_test.cc + subset_random_sampler_test.cc + weighted_random_sampler_test.cc + mnist_op_test.cc + manifest_op_test.cc + voc_op_test.cc + cifar_op_test.cc + celeba_op_test.cc + take_op_test.cc + clue_op_test.cc + text_file_op_test.cc + filter_op_test.cc + concat_op_test.cc + jieba_tokenizer_op_test.cc + tokenizer_op_test.cc + gnn_graph_test.cc + coco_op_test.cc + fill_op_test.cc + 
mask_test.cc + trucate_pair_test.cc + concatenate_op_test.cc + ) add_executable(de_ut_tests ${DE_UT_SRCS}) @@ -88,8 +89,8 @@ target_link_libraries(de_ut_tests PRIVATE _c_dataengine pybind11::embed ${GTEST_ gtest_discover_tests(de_ut_tests WORKING_DIRECTORY ${Project_DIR}/tests/dataset) install(TARGETS de_ut_tests - RUNTIME DESTINATION test) + RUNTIME DESTINATION test) # For internal testing only. install(DIRECTORY ${Project_DIR}/tests/dataset/data/ - DESTINATION test/data) + DESTINATION test/data) diff --git a/tests/ut/cpp/dataset/concatenate_op_test.cc b/tests/ut/cpp/dataset/concatenate_op_test.cc new file mode 100644 index 0000000000..1ceedbac38 --- /dev/null +++ b/tests/ut/cpp/dataset/concatenate_op_test.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "common/common.h"
+#include "dataset/kernels/data/concatenate_op.h"
+#include "utils/log_adapter.h"
+
+using namespace mindspore::dataset;
+using mindspore::LogStream;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::INFO;
+
+class MindDataTestConcatenateOp : public UT::Common {
+ protected:
+  MindDataTestConcatenateOp() {}
+};
+
+TEST_F(MindDataTestConcatenateOp, TestOp) {
+  MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp.";
+  uint64_t labels[3] = {1, 1, 2};
+  TensorShape shape({3});
+  std::shared_ptr<Tensor> input =
+    std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
+
+  uint64_t append_labels[3] = {4, 4, 4};
+  std::shared_ptr<Tensor> append =
+    std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(append_labels));
+
+  std::shared_ptr<Tensor> output;
+  std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append));
+  TensorRow in;
+  in.push_back(input);
+  TensorRow out_row;
+  Status s = op->Compute(in, &out_row);
+  uint64_t out[6] = {1, 1, 2, 4, 4, 4};
+
+  std::shared_ptr<Tensor> expected =
+    std::make_shared<Tensor>(TensorShape{6}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out));
+  ASSERT_TRUE(s.IsOk());  // abort here instead of indexing an empty out_row when Compute failed
+  output = out_row[0];
+  ASSERT_TRUE(output->shape() == expected->shape());
+  ASSERT_TRUE(output->type() == expected->type());
+  MS_LOG(DEBUG) << *output << std::endl;
+  MS_LOG(DEBUG) << *expected << std::endl;
+
+  ASSERT_TRUE(*output == *expected);
+
+  //  std::vector<TensorShape> inputs = {TensorShape({3})};
+  //  std::vector<TensorShape> outputs = {};
+  //  s = op->OutputShape(inputs, outputs);
+  //  EXPECT_TRUE(s.IsOk());
+  //  ASSERT_TRUE(outputs[0] == TensorShape{6});
+  //  MS_LOG(INFO) << "MindDataTestConcatenateOp-TestOp end.";
+}
diff --git a/tests/ut/cpp/dataset/tensor_test.cc b/tests/ut/cpp/dataset/tensor_test.cc
index d47d22fb9c..1e75880a35 100644
--- a/tests/ut/cpp/dataset/tensor_test.cc
+++ b/tests/ut/cpp/dataset/tensor_test.cc
@@ -141,7 +141,6 @@ TEST_F(MindDataTestTensorDE, InsertTensor) { std::shared_ptr t4;
Tensor::CreateTensor(&t4, z, TensorShape({2, 3})); - ASSERT_EQ(*t == *t4, true); std::shared_ptr t5; @@ -407,3 +406,30 @@ TEST_F(MindDataTestTensorDE, TensorSlice) { t->Slice(&t2, std::vector{0, 1, 2, 3, 4}); ASSERT_EQ(*t2, *t); } + +TEST_F(MindDataTestTensorDE, TensorConcatenate) { + std::vector values1 = {1, 2, 3, 0, 0, 0}; + std::vector values2 = {4, 5, 6}; + std::vector expected = {1, 2, 3, 4, 5, 6}; + + std::shared_ptr t1; + Tensor::CreateTensor(&t1, values1); + + std::shared_ptr t2; + Tensor::CreateTensor(&t2, values2); + + std::shared_ptr out; + Tensor::CreateTensor(&out, expected); + Status s = t1->Concatenate({3}, t2); + EXPECT_TRUE(s.IsOk()); + + auto i = out->begin(); + auto j = t1->begin(); + for (; i != out->end(); i++, j++) { + ASSERT_TRUE(*i == *j); + } + + // should fail if the concatenated vector is too large + s = t1->Concatenate({5}, t2); + EXPECT_FALSE(s.IsOk()); +} diff --git a/tests/ut/python/dataset/test_concatenate_op.py b/tests/ut/python/dataset/test_concatenate_op.py new file mode 100644 index 0000000000..c31ab8efb7 --- /dev/null +++ b/tests/ut/python/dataset/test_concatenate_op.py @@ -0,0 +1,175 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""
+Testing concatenate op
+"""
+
+import numpy as np
+import pytest
+
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.c_transforms as data_trans
+
+
+def test_concatenate_op_all():
+    def gen():
+        yield (np.array([5., 6., 7., 8.], dtype=np.float64),)
+
+    prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float64)
+    append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float64)
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor)
+    data = data.map(input_columns=["col"], operations=concatenate_op)
+    expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3,
+                         11., 12.])
+    for data_row in data:
+        np.testing.assert_array_equal(data_row[0], expected)
+
+
+def test_concatenate_op_none():
+    def gen():
+        yield (np.array([5., 6., 7., 8.], dtype=np.float64),)
+
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    concatenate_op = data_trans.Concatenate()
+
+    data = data.map(input_columns=["col"], operations=concatenate_op)
+    for data_row in data:
+        np.testing.assert_array_equal(data_row[0], np.array([5., 6., 7., 8.], dtype=np.float64))
+
+
+def test_concatenate_op_string():
+    def gen():
+        yield (np.array(["ss", "ad"], dtype='S'),)
+
+    prepend_tensor = np.array(["dw", "df"], dtype='S')
+    append_tensor = np.array(["dwsdf", "df"], dtype='S')
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor)
+
+    data = data.map(input_columns=["col"], operations=concatenate_op)
+    expected = np.array(["dw", "df", "ss", "ad", "dwsdf", "df"], dtype='S')
+    for data_row in data:
+        np.testing.assert_array_equal(data_row[0], expected)
+
+
+def test_concatenate_op_multi_input_string():
+    prepend_tensor = np.array(["dw", "df"], dtype='S')
+    append_tensor = np.array(["dwsdf", "df"], dtype='S')
+
+    data = ([["1", "2", "d"]], [["3", "4", "e"]])
+    data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"])
+
+    concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor, append=append_tensor)
+
+    data = data.map(input_columns=["col1", "col2"], columns_order=["out1"], output_columns=["out1"],
+                    operations=concatenate_op)
+    expected = np.array(["dw", "df", "1", "2", "d", "3", "4", "e", "dwsdf", "df"], dtype='S')
+    for data_row in data:
+        np.testing.assert_array_equal(data_row[0], expected)
+
+
+def test_concatenate_op_multi_input_numeric():
+    prepend_tensor = np.array([3, 5])
+
+    data = ([[1, 2]], [[3, 4]])
+    data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"])
+
+    concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor)
+
+    data = data.map(input_columns=["col1", "col2"], columns_order=["out1"], output_columns=["out1"],
+                    operations=concatenate_op)
+    expected = np.array([3, 5, 1, 2, 3, 4])
+    for data_row in data:
+        np.testing.assert_array_equal(data_row[0], expected)
+
+
+def test_concatenate_op_type_mismatch():
+    def gen():
+        yield (np.array([3, 4], dtype=np.float64),)
+
+    prepend_tensor = np.array(["ss", "ad"], dtype='S')
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    concatenate_op = data_trans.Concatenate(0, prepend_tensor)
+
+    data = data.map(input_columns=["col"], operations=concatenate_op)
+    with pytest.raises(RuntimeError) as error_info:
+        for _ in data:
+            pass
+    assert "Tensor types do not match" in repr(error_info.value)
+
+
+def test_concatenate_op_type_mismatch2():
+    def gen():
+        yield (np.array(["ss", "ad"], dtype='S'),)
+
+    prepend_tensor = np.array([3, 5], dtype=np.float64)
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    concatenate_op = data_trans.Concatenate(0, prepend_tensor)
+
+    data = data.map(input_columns=["col"], operations=concatenate_op)
+    with pytest.raises(RuntimeError) as error_info:
+        for _ in data:
+            pass
+    assert "Tensor types do not match" in repr(error_info.value)
+
+
+def test_concatenate_op_incorrect_dim():
+    def gen():
+        yield (np.array([["ss", "ad"], ["ss", "ad"]], dtype='S'),)
+
+    prepend_tensor = np.array([3, 5], dtype=np.float64)
+    concatenate_op = data_trans.Concatenate(0, prepend_tensor)
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+
+    data = data.map(input_columns=["col"], operations=concatenate_op)
+    with pytest.raises(RuntimeError) as error_info:
+        for _ in data:
+            pass
+    assert "Only 1D tensors supported" in repr(error_info.value)
+
+
+def test_concatenate_op_wrong_axis():
+    with pytest.raises(ValueError) as error_info:
+        data_trans.Concatenate(2)
+    assert "only 1D concatenation supported." in repr(error_info.value)
+
+
+def test_concatenate_op_incorrect_input_dim():
+    def gen():
+        yield (np.array(["ss", "ad"], dtype='S'),)
+
+    prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S')
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    concatenate_op = data_trans.Concatenate(0, prepend_tensor)
+
+    data = data.map(input_columns=["col"], operations=concatenate_op)
+    with pytest.raises(RuntimeError) as error_info:
+        for _ in data:
+            pass
+    assert "Only 1D tensors supported" in repr(error_info.value)
+
+
+if __name__ == "__main__":
+    test_concatenate_op_all()
+    test_concatenate_op_none()
+    test_concatenate_op_string()
+    test_concatenate_op_type_mismatch()
+    test_concatenate_op_type_mismatch2()
+    test_concatenate_op_incorrect_dim()
+    test_concatenate_op_incorrect_input_dim()
+    test_concatenate_op_multi_input_numeric()
+    test_concatenate_op_multi_input_string()
+    test_concatenate_op_wrong_axis()