remove graphengine changes concat op Truncate Pair concat_op remove graph engine changestags/v0.5.0-beta
| @@ -17,6 +17,7 @@ | |||||
| #include "dataset/api/de_pipeline.h" | #include "dataset/api/de_pipeline.h" | ||||
| #include "dataset/kernels/no_op.h" | #include "dataset/kernels/no_op.h" | ||||
| #include "dataset/kernels/data/concatenate_op.h" | |||||
| #include "dataset/kernels/data/one_hot_op.h" | #include "dataset/kernels/data/one_hot_op.h" | ||||
| #include "dataset/kernels/image/center_crop_op.h" | #include "dataset/kernels/image/center_crop_op.h" | ||||
| #include "dataset/kernels/image/cut_out_op.h" | #include "dataset/kernels/image/cut_out_op.h" | ||||
| @@ -434,6 +435,11 @@ void bindTensorOps2(py::module *m) { | |||||
| *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length") | *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length") | ||||
| .def(py::init<int64_t>()); | .def(py::init<int64_t>()); | ||||
| (void)py::class_<ConcatenateOp, TensorOp, std::shared_ptr<ConcatenateOp>>(*m, "ConcatenateOp", | |||||
| "Tensor operation concatenate tensors.") | |||||
| .def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>(), py::arg("axis"), | |||||
| py::arg("prepend").none(true), py::arg("append").none(true)); | |||||
| (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>( | (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>( | ||||
| *m, "RandomRotationOp", | *m, "RandomRotationOp", | ||||
| "Tensor operation to apply RandomRotation." | "Tensor operation to apply RandomRotation." | ||||
| @@ -589,11 +589,13 @@ Status Tensor::StartAddrOfIndex(std::vector<dsize_t> ind, uchar **start_addr_of_ | |||||
| if (type() == DataType::DE_STRING) { | if (type() == DataType::DE_STRING) { | ||||
| RETURN_STATUS_UNEXPECTED("StartAddrOfIndex does not support string tensors yet."); | RETURN_STATUS_UNEXPECTED("StartAddrOfIndex does not support string tensors yet."); | ||||
| } | } | ||||
| dsize_t flat_ind; | dsize_t flat_ind; | ||||
| std::vector<dsize_t> t_shape = shape().AsVector(); | std::vector<dsize_t> t_shape = shape().AsVector(); | ||||
| std::vector<dsize_t> r(t_shape.begin() + ind.size(), t_shape.end()); | std::vector<dsize_t> r(t_shape.begin() + ind.size(), t_shape.end()); | ||||
| *remaining = TensorShape(r); | *remaining = TensorShape(r); | ||||
| ind.resize(this->Rank(), 0); // same as -> while (ind.size() < this->Rank()) ind.push_back(0); | ind.resize(this->Rank(), 0); // same as -> while (ind.size() < this->Rank()) ind.push_back(0); | ||||
| RETURN_IF_NOT_OK(shape_.ToFlatIndex(ind, &flat_ind)); | RETURN_IF_NOT_OK(shape_.ToFlatIndex(ind, &flat_ind)); | ||||
| // check if GetBuffer() returns null, we should flag this as an error, this sanity check will only | // check if GetBuffer() returns null, we should flag this as an error, this sanity check will only | ||||
| // be true is the tensor failed to allocate memory. | // be true is the tensor failed to allocate memory. | ||||
| @@ -634,6 +636,39 @@ Status Tensor::InsertTensor(const std::vector<dsize_t> &ind, const std::shared_p | |||||
| } | } | ||||
| } | } | ||||
| Status Tensor::Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &tensor) { | |||||
| std::string err_msg; | |||||
| err_msg += (index.size() != 1) ? "[Tensor] only supports 1d concatenation \n" : ""; | |||||
| err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : ""; | |||||
| err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : ""; | |||||
| err_msg += | |||||
| (index.at(0) + tensor->shape().NumOfElements() > this->shape().NumOfElements()) ? "[Tensor] incorrect index\n" : ""; | |||||
| err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : ""; | |||||
| uchar *start_addr_of_ind = nullptr; | |||||
| TensorShape remaining_shape = tensor->shape(); | |||||
| StartAddrOfIndex(index, &start_addr_of_ind, &remaining_shape); | |||||
| err_msg += (start_addr_of_ind == nullptr) ? "Failed to create memory for Tensor.\n" : ""; | |||||
| if (!err_msg.empty()) { | |||||
| MS_LOG(DEBUG) << "Insert tensor message: " << err_msg; | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } else { | |||||
| int ret_code = | |||||
| memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes()); | |||||
| if (ret_code == 0) { | |||||
| return Status::OK(); | |||||
| } else { | |||||
| err_msg += "[Tensor] error in memcpy_s when inserting tensor\n"; | |||||
| MS_LOG(DEBUG) << "Tensor message: " << err_msg; | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| } | |||||
| } | |||||
| Status Tensor::ExpandDim(const dsize_t &axis) { | Status Tensor::ExpandDim(const dsize_t &axis) { | ||||
| if (axis > Rank()) { | if (axis > Rank()) { | ||||
| std::string err = "Axis is out of bound"; | std::string err = "Axis is out of bound"; | ||||
| @@ -372,6 +372,9 @@ class Tensor { | |||||
| static Status GetBufferInfo(Tensor &t, py::buffer_info *out); | static Status GetBufferInfo(Tensor &t, py::buffer_info *out); | ||||
| // Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor | |||||
| Status Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input); | |||||
| // TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor | // TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor | ||||
| // The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 | // The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 | ||||
| // @tparam T type of values in the Tensor Iterator | // @tparam T type of values in the Tensor Iterator | ||||
| @@ -94,7 +94,7 @@ class TensorShape { | |||||
| // @return | // @return | ||||
| TensorShape PrependDim(dsize_t dim) const; | TensorShape PrependDim(dsize_t dim) const; | ||||
| // Insert a new dim at the end of the shape. For example, <2,4> --> PrependDim(4) --> <2,4,4> | |||||
| // Insert a new dim at the end of the shape. For example, <2,4> --> AppendDim(4) --> <2,4,4> | |||||
| // @param dim | // @param dim | ||||
| // @return | // @return | ||||
| TensorShape AppendDim(dsize_t dim) const; | TensorShape AppendDim(dsize_t dim) const; | ||||
| @@ -1,12 +1,13 @@ | |||||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | ||||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | ||||
| add_library(kernels-data OBJECT | add_library(kernels-data OBJECT | ||||
| data_utils.cc | |||||
| one_hot_op.cc | |||||
| pad_end_op.cc | |||||
| type_cast_op.cc | |||||
| to_float16_op.cc | |||||
| fill_op.cc | |||||
| slice_op.cc | |||||
| mask_op.cc | |||||
| ) | |||||
| data_utils.cc | |||||
| one_hot_op.cc | |||||
| pad_end_op.cc | |||||
| type_cast_op.cc | |||||
| to_float16_op.cc | |||||
| fill_op.cc | |||||
| slice_op.cc | |||||
| mask_op.cc | |||||
| concatenate_op.cc | |||||
| ) | |||||
| @@ -0,0 +1,55 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "dataset/kernels/data/concatenate_op.h" | |||||
| #include "dataset/core/tensor.h" | |||||
| #include "dataset/kernels/data/data_utils.h" | |||||
| #include "dataset/kernels/tensor_op.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| Status ConcatenateOp::Compute(const TensorRow &input, TensorRow *output) { | |||||
| IO_CHECK_VECTOR(input, output); | |||||
| RETURN_IF_NOT_OK(Concatenate(input, output, axis_, prepend_, append_)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ConcatenateOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) { | |||||
| RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); | |||||
| std::vector<TensorShape> inputs_copy; | |||||
| inputs_copy.push_back(inputs[0].Squeeze()); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(inputs.at(0).Rank() == 1, "Only 1D input tensors supported"); | |||||
| outputs.clear(); | |||||
| dsize_t output_shape = 0; | |||||
| output_shape = output_shape + inputs.at(0).NumOfElements(); | |||||
| if (prepend_ != nullptr) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(prepend_->shape().Rank() == 1, "Only 1D prepend tensors supported"); | |||||
| output_shape = output_shape + prepend_->shape().NumOfElements(); | |||||
| } | |||||
| if (append_ != nullptr) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(append_->shape().Rank() == 1, "Only 1D append tensors supported"); | |||||
| output_shape = output_shape + append_->shape().NumOfElements(); | |||||
| } | |||||
| outputs.emplace_back(std::vector<dsize_t>{output_shape}); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,66 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DATASET_KERNELS_DATA_CONCATENATE_OP_H_ | |||||
| #define DATASET_KERNELS_DATA_CONCATENATE_OP_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "dataset/core/tensor.h" | |||||
| #include "dataset/kernels/tensor_op.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class ConcatenateOp : public TensorOp { | |||||
| public: | |||||
| /// Constructor to ConcatenateOp. | |||||
| /// @param int8_t axis - axis to concatenate tensors along. | |||||
| /// @param std::shared_ptr<Tensor> prepend - prepend tensor. | |||||
| /// @param std::shared_ptr<Tensor> append -append tensor. | |||||
| explicit ConcatenateOp(int8_t axis, std::shared_ptr<Tensor> prepend, std::shared_ptr<Tensor> append) | |||||
| : axis_(axis), prepend_(prepend), append_(append) {} | |||||
| ~ConcatenateOp() override = default; | |||||
| /// Print method to see which tensor Op this is. | |||||
| /// @param std::ostream &out - output stream object. | |||||
| void Print(std::ostream &out) const override { out << "ConcatenateOp"; } | |||||
| /// Compute method allowing multiple tensors as inputs | |||||
| /// @param TensorRow &input - input tensor rows | |||||
| /// @param TensorRow *output - output tensor rows | |||||
| Status Compute(const TensorRow &input, TensorRow *output) override; | |||||
| /// Compute tensor output shape | |||||
| /// @param std::vector<TensorShape> &inputs - vector of input tensor shapes | |||||
| /// @param std::vector<TensorShape< &outputs - vector of output tensor shapes | |||||
| Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | |||||
| /// Number of inputs the tensor operation accepts | |||||
| uint32_t NumInput() override { return 0; } | |||||
| private: | |||||
| int8_t axis_; | |||||
| std::shared_ptr<Tensor> prepend_; | |||||
| std::shared_ptr<Tensor> append_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CONCATENATE_OP_H | |||||
| @@ -555,5 +555,80 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend, | |||||
| std::shared_ptr<Tensor> append) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported"); | |||||
| Tensor::HandleNeg(axis, input[0]->shape().Rank()); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported"); | |||||
| std::shared_ptr<Tensor> out; | |||||
| if (prepend != nullptr) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported"); | |||||
| RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0])); | |||||
| } else { | |||||
| out = input[0]; | |||||
| } | |||||
| for (dsize_t i = 1; i < input.size(); i++) { | |||||
| std::shared_ptr<Tensor> out_t; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported"); | |||||
| RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i])); | |||||
| out = out_t; | |||||
| } | |||||
| std::shared_ptr<Tensor> out_t; | |||||
| if (append != nullptr) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported"); | |||||
| RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append)); | |||||
| } else { | |||||
| out_t = out; | |||||
| } | |||||
| output->push_back(out_t); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis, | |||||
| std::shared_ptr<Tensor> append) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match"); | |||||
| TensorShape t({}); | |||||
| for (dsize_t i = 0; i < input->shape().Rank(); i++) { | |||||
| if (i != axis) { | |||||
| t = t.AppendDim(input->shape()[i]); | |||||
| } else { | |||||
| dsize_t new_shape = input->shape()[i] + append->shape()[i]; | |||||
| t = t.AppendDim(new_shape); | |||||
| } | |||||
| } | |||||
| std::shared_ptr<Tensor> out; | |||||
| if (input->type().IsNumeric()) { | |||||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, t, input->type())); | |||||
| RETURN_IF_NOT_OK(out->Concatenate({0}, input)); | |||||
| RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append)); | |||||
| *output = out; | |||||
| } else { | |||||
| std::vector<std::string> strings; | |||||
| auto itr = input->begin<std::string_view>(); | |||||
| for (; itr != input->end<std::string_view>(); itr++) { | |||||
| strings.emplace_back(*itr); | |||||
| } | |||||
| itr = append->begin<std::string_view>(); | |||||
| for (; itr != append->end<std::string_view>(); itr++) { | |||||
| strings.emplace_back(*itr); | |||||
| } | |||||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, t)); | |||||
| *output = out; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include "dataset/core/cv_tensor.h" | #include "dataset/core/cv_tensor.h" | ||||
| #include "dataset/core/data_type.h" | #include "dataset/core/data_type.h" | ||||
| #include "dataset/core/tensor.h" | #include "dataset/core/tensor.h" | ||||
| #include "dataset/core/tensor_row.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -148,6 +149,14 @@ Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Te | |||||
| /// @return Status ok/error | /// @return Status ok/error | ||||
| Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value, | Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value, | ||||
| RelationalOp op); | RelationalOp op); | ||||
| Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend, | |||||
| std::shared_ptr<Tensor> append); | |||||
| // helper for concat, always append to the input, and pass that to the output | |||||
| Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis, | |||||
| std::shared_ptr<Tensor> append); | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,13 +16,13 @@ | |||||
| This module c_transforms provides common operations, including OneHotOp and TypeCast. | This module c_transforms provides common operations, including OneHotOp and TypeCast. | ||||
| """ | """ | ||||
| from enum import IntEnum | from enum import IntEnum | ||||
| import numpy as np | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| import mindspore._c_dataengine as cde | import mindspore._c_dataengine as cde | ||||
| import numpy as np | |||||
| from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, check_pad_end | |||||
| from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, \ | |||||
| check_pad_end, check_concat_type | |||||
| from ..core.datatypes import mstype_to_detype | from ..core.datatypes import mstype_to_detype | ||||
| @@ -187,3 +187,19 @@ class PadEnd(cde.PadEndOp): | |||||
| if pad_value is not None: | if pad_value is not None: | ||||
| pad_value = cde.Tensor(np.array(pad_value)) | pad_value = cde.Tensor(np.array(pad_value)) | ||||
| super().__init__(cde.TensorShape(pad_shape), pad_value) | super().__init__(cde.TensorShape(pad_shape), pad_value) | ||||
| class Concatenate(cde.ConcatenateOp): | |||||
| """ | |||||
| Tensor operation to prepend and append to a tensor. | |||||
| Args: | |||||
| axis (int, optional): axis to concatenate the tensors along (Default=0). | |||||
| prepend (np.array, optional): numpy array to be prepended to the already concatenated tensors (Default=None). | |||||
| append (np.array, optional): numpy array to be appended to the already concatenated tensors (Default=None). | |||||
| """ | |||||
| @check_concat_type | |||||
| def __init__(self, axis=0, prepend=None, append=None): | |||||
| # add some validations here later | |||||
| super().__init__(axis, prepend, append) | |||||
| @@ -15,7 +15,9 @@ | |||||
| """Validators for TensorOps. | """Validators for TensorOps. | ||||
| """ | """ | ||||
| from functools import wraps | from functools import wraps | ||||
| import numpy as np | |||||
| import mindspore._c_dataengine as cde | |||||
| from mindspore._c_expression import typing | from mindspore._c_expression import typing | ||||
| # POS_INT_MIN is used to limit values from starting from 0 | # POS_INT_MIN is used to limit values from starting from 0 | ||||
| @@ -230,10 +232,11 @@ def check_mask_op(method): | |||||
| if operator is None: | if operator is None: | ||||
| raise ValueError("operator is not provided.") | raise ValueError("operator is not provided.") | ||||
| from .c_transforms import Relational | |||||
| if constant is None: | if constant is None: | ||||
| raise ValueError("constant is not provided.") | raise ValueError("constant is not provided.") | ||||
| from .c_transforms import Relational | |||||
| if not isinstance(operator, Relational): | if not isinstance(operator, Relational): | ||||
| raise TypeError("operator is not a Relational operator enum.") | raise TypeError("operator is not a Relational operator enum.") | ||||
| @@ -282,3 +285,46 @@ def check_pad_end(method): | |||||
| return method(self, **kwargs) | return method(self, **kwargs) | ||||
| return new_method | return new_method | ||||
| def check_concat_type(method): | |||||
| """Wrapper method to check the parameters of concatenation op.""" | |||||
| @wraps(method) | |||||
| def new_method(self, *args, **kwargs): | |||||
| axis, prepend, append = (list(args) + 3 * [None])[:3] | |||||
| if "prepend" in kwargs: | |||||
| prepend = kwargs.get("prepend") | |||||
| if "append" in kwargs: | |||||
| append = kwargs.get("append") | |||||
| if "axis" in kwargs: | |||||
| axis = kwargs.get("axis") | |||||
| if not isinstance(axis, (type(None), int)): | |||||
| raise TypeError("axis type is not valid, must be None or an integer.") | |||||
| if isinstance(axis, type(None)): | |||||
| axis = 0 | |||||
| if axis not in (None, 0, -1): | |||||
| raise ValueError("only 1D concatenation supported.") | |||||
| if not isinstance(prepend, (type(None), np.ndarray)): | |||||
| raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.") | |||||
| if not isinstance(append, (type(None), np.ndarray)): | |||||
| raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.") | |||||
| if isinstance(prepend, np.ndarray): | |||||
| prepend = cde.Tensor(prepend) | |||||
| if isinstance(append, np.ndarray): | |||||
| append = cde.Tensor(append) | |||||
| kwargs["axis"] = axis | |||||
| kwargs["prepend"] = prepend | |||||
| kwargs["append"] = append | |||||
| return method(self, **kwargs) | |||||
| return new_method | |||||
| @@ -1,83 +1,84 @@ | |||||
| include(GoogleTest) | include(GoogleTest) | ||||
| SET(DE_UT_SRCS | SET(DE_UT_SRCS | ||||
| common/common.cc | |||||
| common/cvop_common.cc | |||||
| batch_op_test.cc | |||||
| bit_functions_test.cc | |||||
| storage_container_test.cc | |||||
| treap_test.cc | |||||
| interrupt_test.cc | |||||
| image_folder_op_test.cc | |||||
| buddy_test.cc | |||||
| arena_test.cc | |||||
| btree_test.cc | |||||
| center_crop_op_test.cc | |||||
| channel_swap_test.cc | |||||
| circular_pool_test.cc | |||||
| client_config_test.cc | |||||
| connector_test.cc | |||||
| datatype_test.cc | |||||
| decode_op_test.cc | |||||
| execution_tree_test.cc | |||||
| global_context_test.cc | |||||
| main_test.cc | |||||
| map_op_test.cc | |||||
| mind_record_op_test.cc | |||||
| memory_pool_test.cc | |||||
| normalize_op_test.cc | |||||
| one_hot_op_test.cc | |||||
| pad_end_op_test.cc | |||||
| path_test.cc | |||||
| project_op_test.cc | |||||
| queue_test.cc | |||||
| random_crop_op_test.cc | |||||
| random_crop_decode_resize_op_test.cc | |||||
| random_crop_and_resize_op_test.cc | |||||
| random_color_adjust_op_test.cc | |||||
| random_horizontal_flip_op_test.cc | |||||
| random_resize_op_test.cc | |||||
| random_rotation_op_test.cc | |||||
| random_vertical_flip_op_test.cc | |||||
| rename_op_test.cc | |||||
| repeat_op_test.cc | |||||
| skip_op_test.cc | |||||
| rescale_op_test.cc | |||||
| resize_bilinear_op_test.cc | |||||
| resize_op_test.cc | |||||
| shuffle_op_test.cc | |||||
| stand_alone_samplers_test.cc | |||||
| status_test.cc | |||||
| storage_op_test.cc | |||||
| task_manager_test.cc | |||||
| tensor_test.cc | |||||
| tensor_string_test.cc | |||||
| tensorshape_test.cc | |||||
| tfReader_op_test.cc | |||||
| to_float16_op_test.cc | |||||
| type_cast_op_test.cc | |||||
| zip_op_test.cc | |||||
| random_resize_op_test.cc | |||||
| subset_random_sampler_test.cc | |||||
| weighted_random_sampler_test.cc | |||||
| mnist_op_test.cc | |||||
| manifest_op_test.cc | |||||
| voc_op_test.cc | |||||
| cifar_op_test.cc | |||||
| celeba_op_test.cc | |||||
| take_op_test.cc | |||||
| clue_op_test.cc | |||||
| text_file_op_test.cc | |||||
| filter_op_test.cc | |||||
| concat_op_test.cc | |||||
| jieba_tokenizer_op_test.cc | |||||
| tokenizer_op_test.cc | |||||
| gnn_graph_test.cc | |||||
| coco_op_test.cc | |||||
| fill_op_test.cc | |||||
| mask_test.cc | |||||
| trucate_pair_test.cc | |||||
| ) | |||||
| common/common.cc | |||||
| common/cvop_common.cc | |||||
| batch_op_test.cc | |||||
| bit_functions_test.cc | |||||
| storage_container_test.cc | |||||
| treap_test.cc | |||||
| interrupt_test.cc | |||||
| image_folder_op_test.cc | |||||
| buddy_test.cc | |||||
| arena_test.cc | |||||
| btree_test.cc | |||||
| center_crop_op_test.cc | |||||
| channel_swap_test.cc | |||||
| circular_pool_test.cc | |||||
| client_config_test.cc | |||||
| connector_test.cc | |||||
| datatype_test.cc | |||||
| decode_op_test.cc | |||||
| execution_tree_test.cc | |||||
| global_context_test.cc | |||||
| main_test.cc | |||||
| map_op_test.cc | |||||
| mind_record_op_test.cc | |||||
| memory_pool_test.cc | |||||
| normalize_op_test.cc | |||||
| one_hot_op_test.cc | |||||
| pad_end_op_test.cc | |||||
| path_test.cc | |||||
| project_op_test.cc | |||||
| queue_test.cc | |||||
| random_crop_op_test.cc | |||||
| random_crop_decode_resize_op_test.cc | |||||
| random_crop_and_resize_op_test.cc | |||||
| random_color_adjust_op_test.cc | |||||
| random_horizontal_flip_op_test.cc | |||||
| random_resize_op_test.cc | |||||
| random_rotation_op_test.cc | |||||
| random_vertical_flip_op_test.cc | |||||
| rename_op_test.cc | |||||
| repeat_op_test.cc | |||||
| skip_op_test.cc | |||||
| rescale_op_test.cc | |||||
| resize_bilinear_op_test.cc | |||||
| resize_op_test.cc | |||||
| shuffle_op_test.cc | |||||
| stand_alone_samplers_test.cc | |||||
| status_test.cc | |||||
| storage_op_test.cc | |||||
| task_manager_test.cc | |||||
| tensor_test.cc | |||||
| tensor_string_test.cc | |||||
| tensorshape_test.cc | |||||
| tfReader_op_test.cc | |||||
| to_float16_op_test.cc | |||||
| type_cast_op_test.cc | |||||
| zip_op_test.cc | |||||
| random_resize_op_test.cc | |||||
| subset_random_sampler_test.cc | |||||
| weighted_random_sampler_test.cc | |||||
| mnist_op_test.cc | |||||
| manifest_op_test.cc | |||||
| voc_op_test.cc | |||||
| cifar_op_test.cc | |||||
| celeba_op_test.cc | |||||
| take_op_test.cc | |||||
| clue_op_test.cc | |||||
| text_file_op_test.cc | |||||
| filter_op_test.cc | |||||
| concat_op_test.cc | |||||
| jieba_tokenizer_op_test.cc | |||||
| tokenizer_op_test.cc | |||||
| gnn_graph_test.cc | |||||
| coco_op_test.cc | |||||
| fill_op_test.cc | |||||
| mask_test.cc | |||||
| trucate_pair_test.cc | |||||
| concatenate_op_test.cc | |||||
| ) | |||||
| add_executable(de_ut_tests ${DE_UT_SRCS}) | add_executable(de_ut_tests ${DE_UT_SRCS}) | ||||
| @@ -88,8 +89,8 @@ target_link_libraries(de_ut_tests PRIVATE _c_dataengine pybind11::embed ${GTEST_ | |||||
| gtest_discover_tests(de_ut_tests WORKING_DIRECTORY ${Project_DIR}/tests/dataset) | gtest_discover_tests(de_ut_tests WORKING_DIRECTORY ${Project_DIR}/tests/dataset) | ||||
| install(TARGETS de_ut_tests | install(TARGETS de_ut_tests | ||||
| RUNTIME DESTINATION test) | |||||
| RUNTIME DESTINATION test) | |||||
| # For internal testing only. | # For internal testing only. | ||||
| install(DIRECTORY ${Project_DIR}/tests/dataset/data/ | install(DIRECTORY ${Project_DIR}/tests/dataset/data/ | ||||
| DESTINATION test/data) | |||||
| DESTINATION test/data) | |||||
| @@ -0,0 +1,66 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/common.h" | |||||
| #include "dataset/kernels/data/concatenate_op.h" | |||||
| #include "utils/log_adapter.h" | |||||
| using namespace mindspore::dataset; | |||||
| using mindspore::LogStream; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::MsLogLevel::INFO; | |||||
| class MindDataTestConcatenateOp : public UT::Common { | |||||
| protected: | |||||
| MindDataTestConcatenateOp() {} | |||||
| }; | |||||
| TEST_F(MindDataTestConcatenateOp, TestOp) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp."; | |||||
| uint64_t labels[3] = {1, 1, 2}; | |||||
| TensorShape shape({3}); | |||||
| std::shared_ptr<Tensor> input = | |||||
| std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels)); | |||||
| uint64_t append_labels[3] = {4, 4, 4}; | |||||
| std::shared_ptr<Tensor> append = | |||||
| std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(append_labels)); | |||||
| std::shared_ptr<Tensor> output; | |||||
| std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append)); | |||||
| TensorRow in; | |||||
| in.push_back(input); | |||||
| TensorRow out_row; | |||||
| Status s = op->Compute(in, &out_row); | |||||
| uint64_t out[6] = {1, 1, 2, 4, 4, 4}; | |||||
| std::shared_ptr<Tensor> expected = | |||||
| std::make_shared<Tensor>(TensorShape{6}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out)); | |||||
| output = out_row[0]; | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| ASSERT_TRUE(output->shape() == expected->shape()); | |||||
| ASSERT_TRUE(output->type() == expected->type()); | |||||
| MS_LOG(DEBUG) << *output << std::endl; | |||||
| MS_LOG(DEBUG) << *expected << std::endl; | |||||
| ASSERT_TRUE(*output == *expected); | |||||
| // std::vector<TensorShape> inputs = {TensorShape({3})}; | |||||
| // std::vector<TensorShape> outputs = {}; | |||||
| // s = op->OutputShape(inputs, outputs); | |||||
| // EXPECT_TRUE(s.IsOk()); | |||||
| // ASSERT_TRUE(outputs[0] == TensorShape{6}); | |||||
| // MS_LOG(INFO) << "MindDataTestConcatenateOp-TestOp end."; | |||||
| } | |||||
| @@ -141,7 +141,6 @@ TEST_F(MindDataTestTensorDE, InsertTensor) { | |||||
| std::shared_ptr<Tensor> t4; | std::shared_ptr<Tensor> t4; | ||||
| Tensor::CreateTensor(&t4, z, TensorShape({2, 3})); | Tensor::CreateTensor(&t4, z, TensorShape({2, 3})); | ||||
| ASSERT_EQ(*t == *t4, true); | ASSERT_EQ(*t == *t4, true); | ||||
| std::shared_ptr<Tensor> t5; | std::shared_ptr<Tensor> t5; | ||||
| @@ -407,3 +406,30 @@ TEST_F(MindDataTestTensorDE, TensorSlice) { | |||||
| t->Slice(&t2, std::vector<dsize_t>{0, 1, 2, 3, 4}); | t->Slice(&t2, std::vector<dsize_t>{0, 1, 2, 3, 4}); | ||||
| ASSERT_EQ(*t2, *t); | ASSERT_EQ(*t2, *t); | ||||
| } | } | ||||
| TEST_F(MindDataTestTensorDE, TensorConcatenate) { | |||||
| std::vector<uint32_t> values1 = {1, 2, 3, 0, 0, 0}; | |||||
| std::vector<uint32_t> values2 = {4, 5, 6}; | |||||
| std::vector<uint32_t> expected = {1, 2, 3, 4, 5, 6}; | |||||
| std::shared_ptr<Tensor> t1; | |||||
| Tensor::CreateTensor(&t1, values1); | |||||
| std::shared_ptr<Tensor> t2; | |||||
| Tensor::CreateTensor(&t2, values2); | |||||
| std::shared_ptr<Tensor> out; | |||||
| Tensor::CreateTensor(&out, expected); | |||||
| Status s = t1->Concatenate({3}, t2); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| auto i = out->begin<uint32_t>(); | |||||
| auto j = t1->begin<uint32_t>(); | |||||
| for (; i != out->end<uint32_t>(); i++, j++) { | |||||
| ASSERT_TRUE(*i == *j); | |||||
| } | |||||
| // should fail if the concatenated vector is too large | |||||
| s = t1->Concatenate({5}, t2); | |||||
| EXPECT_FALSE(s.IsOk()); | |||||
| } | |||||
| @@ -0,0 +1,175 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """ | |||||
| Testing concatenate op | |||||
| """ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.transforms.c_transforms as data_trans | |||||
| def test_concatenate_op_all(): | |||||
| def gen(): | |||||
| yield (np.array([5., 6., 7., 8.], dtype=np.float),) | |||||
| prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float) | |||||
| append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float) | |||||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor) | |||||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||||
| expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3, | |||||
| 11., 12.]) | |||||
| for data_row in data: | |||||
| np.testing.assert_array_equal(data_row[0], expected) | |||||
| def test_concatenate_op_none(): | |||||
| def gen(): | |||||
| yield (np.array([5., 6., 7., 8.], dtype=np.float),) | |||||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||||
| concatenate_op = data_trans.Concatenate() | |||||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||||
| for data_row in data: | |||||
| np.testing.assert_array_equal(data_row[0], np.array([5., 6., 7., 8.], dtype=np.float)) | |||||
| def test_concatenate_op_string(): | |||||
| def gen(): | |||||
| yield (np.array(["ss", "ad"], dtype='S'),) | |||||
| prepend_tensor = np.array(["dw", "df"], dtype='S') | |||||
| append_tensor = np.array(["dwsdf", "df"], dtype='S') | |||||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor) | |||||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||||
| expected = np.array(["dw", "df", "ss", "ad", "dwsdf", "df"], dtype='S') | |||||
| for data_row in data: | |||||
| np.testing.assert_array_equal(data_row[0], expected) | |||||
| def test_concatenate_op_multi_input_string(): | |||||
| prepend_tensor = np.array(["dw", "df"], dtype='S') | |||||
| append_tensor = np.array(["dwsdf", "df"], dtype='S') | |||||
| data = ([["1", "2", "d"]], [["3", "4", "e"]]) | |||||
| data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"]) | |||||
| concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor, append=append_tensor) | |||||
| data = data.map(input_columns=["col1", "col2"], columns_order=["out1"], output_columns=["out1"], | |||||
| operations=concatenate_op) | |||||
| expected = np.array(["dw", "df", "1", "2", "d", "3", "4", "e", "dwsdf", "df"], dtype='S') | |||||
| for data_row in data: | |||||
| np.testing.assert_array_equal(data_row[0], expected) | |||||
| def test_concatenate_op_multi_input_numeric(): | |||||
| prepend_tensor = np.array([3, 5]) | |||||
| data = ([[1, 2]], [[3, 4]]) | |||||
| data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"]) | |||||
| concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor) | |||||
| data = data.map(input_columns=["col1", "col2"], columns_order=["out1"], output_columns=["out1"], | |||||
| operations=concatenate_op) | |||||
| expected = np.array([3, 5, 1, 2, 3, 4]) | |||||
| for data_row in data: | |||||
| np.testing.assert_array_equal(data_row[0], expected) | |||||
| def test_concatenate_op_type_mismatch(): | |||||
| def gen(): | |||||
| yield (np.array([3, 4], dtype=np.float),) | |||||
| prepend_tensor = np.array(["ss", "ad"], dtype='S') | |||||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor) | |||||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||||
| with pytest.raises(RuntimeError) as error_info: | |||||
| for _ in data: | |||||
| pass | |||||
| assert "Tensor types do not match" in repr(error_info.value) | |||||
| def test_concatenate_op_type_mismatch2(): | |||||
| def gen(): | |||||
| yield (np.array(["ss", "ad"], dtype='S'),) | |||||
| prepend_tensor = np.array([3, 5], dtype=np.float) | |||||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor) | |||||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||||
| with pytest.raises(RuntimeError) as error_info: | |||||
| for _ in data: | |||||
| pass | |||||
| assert "Tensor types do not match" in repr(error_info.value) | |||||
| def test_concatenate_op_incorrect_dim(): | |||||
| def gen(): | |||||
| yield (np.array([["ss", "ad"], ["ss", "ad"]], dtype='S'),) | |||||
| prepend_tensor = np.array([3, 5], dtype=np.float) | |||||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor) | |||||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||||
| with pytest.raises(RuntimeError) as error_info: | |||||
| for _ in data: | |||||
| pass | |||||
| assert "Only 1D tensors supported" in repr(error_info.value) | |||||
| def test_concatenate_op_wrong_axis(): | |||||
| with pytest.raises(ValueError) as error_info: | |||||
| data_trans.Concatenate(2) | |||||
| assert "only 1D concatenation supported." in repr(error_info.value) | |||||
| def test_concatenate_op_incorrect_input_dim(): | |||||
| def gen(): | |||||
| yield (np.array(["ss", "ad"], dtype='S'),) | |||||
| prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S') | |||||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor) | |||||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||||
| with pytest.raises(RuntimeError) as error_info: | |||||
| for _ in data: | |||||
| pass | |||||
| assert "Only 1D tensors supported" in repr(error_info.value) | |||||
| if __name__ == "__main__": | |||||
| test_concatenate_op_all() | |||||
| test_concatenate_op_none() | |||||
| test_concatenate_op_string() | |||||
| test_concatenate_op_type_mismatch() | |||||
| test_concatenate_op_type_mismatch2() | |||||
| test_concatenate_op_incorrect_dim() | |||||
| test_concatenate_op_incorrect_input_dim() | |||||
| test_concatenate_op_multi_input_numeric() | |||||
| test_concatenate_op_multi_input_string() | |||||
| test_concatenate_op_wrong_axis() | |||||