remove graphengine changes concat op Truncate Pair concat_op remove graph engine changestags/v0.5.0-beta
| @@ -17,6 +17,7 @@ | |||
| #include "dataset/api/de_pipeline.h" | |||
| #include "dataset/kernels/no_op.h" | |||
| #include "dataset/kernels/data/concatenate_op.h" | |||
| #include "dataset/kernels/data/one_hot_op.h" | |||
| #include "dataset/kernels/image/center_crop_op.h" | |||
| #include "dataset/kernels/image/cut_out_op.h" | |||
| @@ -434,6 +435,11 @@ void bindTensorOps2(py::module *m) { | |||
| *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length") | |||
| .def(py::init<int64_t>()); | |||
| (void)py::class_<ConcatenateOp, TensorOp, std::shared_ptr<ConcatenateOp>>(*m, "ConcatenateOp", | |||
| "Tensor operation concatenate tensors.") | |||
| .def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>(), py::arg("axis"), | |||
| py::arg("prepend").none(true), py::arg("append").none(true)); | |||
| (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>( | |||
| *m, "RandomRotationOp", | |||
| "Tensor operation to apply RandomRotation." | |||
| @@ -589,11 +589,13 @@ Status Tensor::StartAddrOfIndex(std::vector<dsize_t> ind, uchar **start_addr_of_ | |||
| if (type() == DataType::DE_STRING) { | |||
| RETURN_STATUS_UNEXPECTED("StartAddrOfIndex does not support string tensors yet."); | |||
| } | |||
| dsize_t flat_ind; | |||
| std::vector<dsize_t> t_shape = shape().AsVector(); | |||
| std::vector<dsize_t> r(t_shape.begin() + ind.size(), t_shape.end()); | |||
| *remaining = TensorShape(r); | |||
| ind.resize(this->Rank(), 0); // same as -> while (ind.size() < this->Rank()) ind.push_back(0); | |||
| RETURN_IF_NOT_OK(shape_.ToFlatIndex(ind, &flat_ind)); | |||
| // check if GetBuffer() returns null, we should flag this as an error, this sanity check will only | |||
| // be true is the tensor failed to allocate memory. | |||
| @@ -634,6 +636,39 @@ Status Tensor::InsertTensor(const std::vector<dsize_t> &ind, const std::shared_p | |||
| } | |||
| } | |||
| Status Tensor::Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &tensor) { | |||
| std::string err_msg; | |||
| err_msg += (index.size() != 1) ? "[Tensor] only supports 1d concatenation \n" : ""; | |||
| err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : ""; | |||
| err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : ""; | |||
| err_msg += | |||
| (index.at(0) + tensor->shape().NumOfElements() > this->shape().NumOfElements()) ? "[Tensor] incorrect index\n" : ""; | |||
| err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : ""; | |||
| uchar *start_addr_of_ind = nullptr; | |||
| TensorShape remaining_shape = tensor->shape(); | |||
| StartAddrOfIndex(index, &start_addr_of_ind, &remaining_shape); | |||
| err_msg += (start_addr_of_ind == nullptr) ? "Failed to create memory for Tensor.\n" : ""; | |||
| if (!err_msg.empty()) { | |||
| MS_LOG(DEBUG) << "Insert tensor message: " << err_msg; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } else { | |||
| int ret_code = | |||
| memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes()); | |||
| if (ret_code == 0) { | |||
| return Status::OK(); | |||
| } else { | |||
| err_msg += "[Tensor] error in memcpy_s when inserting tensor\n"; | |||
| MS_LOG(DEBUG) << "Tensor message: " << err_msg; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| } | |||
| } | |||
| Status Tensor::ExpandDim(const dsize_t &axis) { | |||
| if (axis > Rank()) { | |||
| std::string err = "Axis is out of bound"; | |||
| @@ -372,6 +372,9 @@ class Tensor { | |||
| static Status GetBufferInfo(Tensor &t, py::buffer_info *out); | |||
| // Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor | |||
| Status Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input); | |||
| // TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor | |||
| // The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 | |||
| // @tparam T type of values in the Tensor Iterator | |||
| @@ -94,7 +94,7 @@ class TensorShape { | |||
| // @return | |||
| TensorShape PrependDim(dsize_t dim) const; | |||
| // Insert a new dim at the end of the shape. For example, <2,4> --> PrependDim(4) --> <2,4,4> | |||
| // Insert a new dim at the end of the shape. For example, <2,4> --> AppendDim(4) --> <2,4,4> | |||
| // @param dim | |||
| // @return | |||
| TensorShape AppendDim(dsize_t dim) const; | |||
| @@ -1,12 +1,13 @@ | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_library(kernels-data OBJECT | |||
| data_utils.cc | |||
| one_hot_op.cc | |||
| pad_end_op.cc | |||
| type_cast_op.cc | |||
| to_float16_op.cc | |||
| fill_op.cc | |||
| slice_op.cc | |||
| mask_op.cc | |||
| ) | |||
| data_utils.cc | |||
| one_hot_op.cc | |||
| pad_end_op.cc | |||
| type_cast_op.cc | |||
| to_float16_op.cc | |||
| fill_op.cc | |||
| slice_op.cc | |||
| mask_op.cc | |||
| concatenate_op.cc | |||
| ) | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/kernels/data/concatenate_op.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/data/data_utils.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status ConcatenateOp::Compute(const TensorRow &input, TensorRow *output) { | |||
| IO_CHECK_VECTOR(input, output); | |||
| RETURN_IF_NOT_OK(Concatenate(input, output, axis_, prepend_, append_)); | |||
| return Status::OK(); | |||
| } | |||
| Status ConcatenateOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) { | |||
| RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); | |||
| std::vector<TensorShape> inputs_copy; | |||
| inputs_copy.push_back(inputs[0].Squeeze()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(inputs.at(0).Rank() == 1, "Only 1D input tensors supported"); | |||
| outputs.clear(); | |||
| dsize_t output_shape = 0; | |||
| output_shape = output_shape + inputs.at(0).NumOfElements(); | |||
| if (prepend_ != nullptr) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(prepend_->shape().Rank() == 1, "Only 1D prepend tensors supported"); | |||
| output_shape = output_shape + prepend_->shape().NumOfElements(); | |||
| } | |||
| if (append_ != nullptr) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(append_->shape().Rank() == 1, "Only 1D append tensors supported"); | |||
| output_shape = output_shape + append_->shape().NumOfElements(); | |||
| } | |||
| outputs.emplace_back(std::vector<dsize_t>{output_shape}); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_KERNELS_DATA_CONCATENATE_OP_H_ | |||
| #define DATASET_KERNELS_DATA_CONCATENATE_OP_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class ConcatenateOp : public TensorOp { | |||
| public: | |||
| /// Constructor to ConcatenateOp. | |||
| /// @param int8_t axis - axis to concatenate tensors along. | |||
| /// @param std::shared_ptr<Tensor> prepend - prepend tensor. | |||
| /// @param std::shared_ptr<Tensor> append -append tensor. | |||
| explicit ConcatenateOp(int8_t axis, std::shared_ptr<Tensor> prepend, std::shared_ptr<Tensor> append) | |||
| : axis_(axis), prepend_(prepend), append_(append) {} | |||
| ~ConcatenateOp() override = default; | |||
| /// Print method to see which tensor Op this is. | |||
| /// @param std::ostream &out - output stream object. | |||
| void Print(std::ostream &out) const override { out << "ConcatenateOp"; } | |||
| /// Compute method allowing multiple tensors as inputs | |||
| /// @param TensorRow &input - input tensor rows | |||
| /// @param TensorRow *output - output tensor rows | |||
| Status Compute(const TensorRow &input, TensorRow *output) override; | |||
| /// Compute tensor output shape | |||
| /// @param std::vector<TensorShape> &inputs - vector of input tensor shapes | |||
| /// @param std::vector<TensorShape< &outputs - vector of output tensor shapes | |||
| Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | |||
| /// Number of inputs the tensor operation accepts | |||
| uint32_t NumInput() override { return 0; } | |||
| private: | |||
| int8_t axis_; | |||
| std::shared_ptr<Tensor> prepend_; | |||
| std::shared_ptr<Tensor> append_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CONCATENATE_OP_H | |||
| @@ -555,5 +555,80 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend, | |||
| std::shared_ptr<Tensor> append) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported"); | |||
| Tensor::HandleNeg(axis, input[0]->shape().Rank()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported"); | |||
| std::shared_ptr<Tensor> out; | |||
| if (prepend != nullptr) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported"); | |||
| RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0])); | |||
| } else { | |||
| out = input[0]; | |||
| } | |||
| for (dsize_t i = 1; i < input.size(); i++) { | |||
| std::shared_ptr<Tensor> out_t; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported"); | |||
| RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i])); | |||
| out = out_t; | |||
| } | |||
| std::shared_ptr<Tensor> out_t; | |||
| if (append != nullptr) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported"); | |||
| RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append)); | |||
| } else { | |||
| out_t = out; | |||
| } | |||
| output->push_back(out_t); | |||
| return Status::OK(); | |||
| } | |||
| Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis, | |||
| std::shared_ptr<Tensor> append) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match"); | |||
| TensorShape t({}); | |||
| for (dsize_t i = 0; i < input->shape().Rank(); i++) { | |||
| if (i != axis) { | |||
| t = t.AppendDim(input->shape()[i]); | |||
| } else { | |||
| dsize_t new_shape = input->shape()[i] + append->shape()[i]; | |||
| t = t.AppendDim(new_shape); | |||
| } | |||
| } | |||
| std::shared_ptr<Tensor> out; | |||
| if (input->type().IsNumeric()) { | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, t, input->type())); | |||
| RETURN_IF_NOT_OK(out->Concatenate({0}, input)); | |||
| RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append)); | |||
| *output = out; | |||
| } else { | |||
| std::vector<std::string> strings; | |||
| auto itr = input->begin<std::string_view>(); | |||
| for (; itr != input->end<std::string_view>(); itr++) { | |||
| strings.emplace_back(*itr); | |||
| } | |||
| itr = append->begin<std::string_view>(); | |||
| for (; itr != append->end<std::string_view>(); itr++) { | |||
| strings.emplace_back(*itr); | |||
| } | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, t)); | |||
| *output = out; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -23,6 +23,7 @@ | |||
| #include "dataset/core/cv_tensor.h" | |||
| #include "dataset/core/data_type.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/core/tensor_row.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -148,6 +149,14 @@ Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Te | |||
| /// @return Status ok/error | |||
| Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value, | |||
| RelationalOp op); | |||
| Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend, | |||
| std::shared_ptr<Tensor> append); | |||
| // helper for concat, always append to the input, and pass that to the output | |||
| Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis, | |||
| std::shared_ptr<Tensor> append); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -16,13 +16,13 @@ | |||
| This module c_transforms provides common operations, including OneHotOp and TypeCast. | |||
| """ | |||
| from enum import IntEnum | |||
| import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore._c_dataengine as cde | |||
| import numpy as np | |||
| from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, check_pad_end | |||
| from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, \ | |||
| check_pad_end, check_concat_type | |||
| from ..core.datatypes import mstype_to_detype | |||
| @@ -187,3 +187,19 @@ class PadEnd(cde.PadEndOp): | |||
| if pad_value is not None: | |||
| pad_value = cde.Tensor(np.array(pad_value)) | |||
| super().__init__(cde.TensorShape(pad_shape), pad_value) | |||
| class Concatenate(cde.ConcatenateOp): | |||
| """ | |||
| Tensor operation to prepend and append to a tensor. | |||
| Args: | |||
| axis (int, optional): axis to concatenate the tensors along (Default=0). | |||
| prepend (np.array, optional): numpy array to be prepended to the already concatenated tensors (Default=None). | |||
| append (np.array, optional): numpy array to be appended to the already concatenated tensors (Default=None). | |||
| """ | |||
| @check_concat_type | |||
| def __init__(self, axis=0, prepend=None, append=None): | |||
| # add some validations here later | |||
| super().__init__(axis, prepend, append) | |||
| @@ -15,7 +15,9 @@ | |||
| """Validators for TensorOps. | |||
| """ | |||
| from functools import wraps | |||
| import numpy as np | |||
| import mindspore._c_dataengine as cde | |||
| from mindspore._c_expression import typing | |||
| # POS_INT_MIN is used to limit values from starting from 0 | |||
| @@ -230,10 +232,11 @@ def check_mask_op(method): | |||
| if operator is None: | |||
| raise ValueError("operator is not provided.") | |||
| from .c_transforms import Relational | |||
| if constant is None: | |||
| raise ValueError("constant is not provided.") | |||
| from .c_transforms import Relational | |||
| if not isinstance(operator, Relational): | |||
| raise TypeError("operator is not a Relational operator enum.") | |||
| @@ -282,3 +285,46 @@ def check_pad_end(method): | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| def check_concat_type(method): | |||
| """Wrapper method to check the parameters of concatenation op.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| axis, prepend, append = (list(args) + 3 * [None])[:3] | |||
| if "prepend" in kwargs: | |||
| prepend = kwargs.get("prepend") | |||
| if "append" in kwargs: | |||
| append = kwargs.get("append") | |||
| if "axis" in kwargs: | |||
| axis = kwargs.get("axis") | |||
| if not isinstance(axis, (type(None), int)): | |||
| raise TypeError("axis type is not valid, must be None or an integer.") | |||
| if isinstance(axis, type(None)): | |||
| axis = 0 | |||
| if axis not in (None, 0, -1): | |||
| raise ValueError("only 1D concatenation supported.") | |||
| if not isinstance(prepend, (type(None), np.ndarray)): | |||
| raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.") | |||
| if not isinstance(append, (type(None), np.ndarray)): | |||
| raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.") | |||
| if isinstance(prepend, np.ndarray): | |||
| prepend = cde.Tensor(prepend) | |||
| if isinstance(append, np.ndarray): | |||
| append = cde.Tensor(append) | |||
| kwargs["axis"] = axis | |||
| kwargs["prepend"] = prepend | |||
| kwargs["append"] = append | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| @@ -1,83 +1,84 @@ | |||
| include(GoogleTest) | |||
| SET(DE_UT_SRCS | |||
| common/common.cc | |||
| common/cvop_common.cc | |||
| batch_op_test.cc | |||
| bit_functions_test.cc | |||
| storage_container_test.cc | |||
| treap_test.cc | |||
| interrupt_test.cc | |||
| image_folder_op_test.cc | |||
| buddy_test.cc | |||
| arena_test.cc | |||
| btree_test.cc | |||
| center_crop_op_test.cc | |||
| channel_swap_test.cc | |||
| circular_pool_test.cc | |||
| client_config_test.cc | |||
| connector_test.cc | |||
| datatype_test.cc | |||
| decode_op_test.cc | |||
| execution_tree_test.cc | |||
| global_context_test.cc | |||
| main_test.cc | |||
| map_op_test.cc | |||
| mind_record_op_test.cc | |||
| memory_pool_test.cc | |||
| normalize_op_test.cc | |||
| one_hot_op_test.cc | |||
| pad_end_op_test.cc | |||
| path_test.cc | |||
| project_op_test.cc | |||
| queue_test.cc | |||
| random_crop_op_test.cc | |||
| random_crop_decode_resize_op_test.cc | |||
| random_crop_and_resize_op_test.cc | |||
| random_color_adjust_op_test.cc | |||
| random_horizontal_flip_op_test.cc | |||
| random_resize_op_test.cc | |||
| random_rotation_op_test.cc | |||
| random_vertical_flip_op_test.cc | |||
| rename_op_test.cc | |||
| repeat_op_test.cc | |||
| skip_op_test.cc | |||
| rescale_op_test.cc | |||
| resize_bilinear_op_test.cc | |||
| resize_op_test.cc | |||
| shuffle_op_test.cc | |||
| stand_alone_samplers_test.cc | |||
| status_test.cc | |||
| storage_op_test.cc | |||
| task_manager_test.cc | |||
| tensor_test.cc | |||
| tensor_string_test.cc | |||
| tensorshape_test.cc | |||
| tfReader_op_test.cc | |||
| to_float16_op_test.cc | |||
| type_cast_op_test.cc | |||
| zip_op_test.cc | |||
| random_resize_op_test.cc | |||
| subset_random_sampler_test.cc | |||
| weighted_random_sampler_test.cc | |||
| mnist_op_test.cc | |||
| manifest_op_test.cc | |||
| voc_op_test.cc | |||
| cifar_op_test.cc | |||
| celeba_op_test.cc | |||
| take_op_test.cc | |||
| clue_op_test.cc | |||
| text_file_op_test.cc | |||
| filter_op_test.cc | |||
| concat_op_test.cc | |||
| jieba_tokenizer_op_test.cc | |||
| tokenizer_op_test.cc | |||
| gnn_graph_test.cc | |||
| coco_op_test.cc | |||
| fill_op_test.cc | |||
| mask_test.cc | |||
| trucate_pair_test.cc | |||
| ) | |||
| common/common.cc | |||
| common/cvop_common.cc | |||
| batch_op_test.cc | |||
| bit_functions_test.cc | |||
| storage_container_test.cc | |||
| treap_test.cc | |||
| interrupt_test.cc | |||
| image_folder_op_test.cc | |||
| buddy_test.cc | |||
| arena_test.cc | |||
| btree_test.cc | |||
| center_crop_op_test.cc | |||
| channel_swap_test.cc | |||
| circular_pool_test.cc | |||
| client_config_test.cc | |||
| connector_test.cc | |||
| datatype_test.cc | |||
| decode_op_test.cc | |||
| execution_tree_test.cc | |||
| global_context_test.cc | |||
| main_test.cc | |||
| map_op_test.cc | |||
| mind_record_op_test.cc | |||
| memory_pool_test.cc | |||
| normalize_op_test.cc | |||
| one_hot_op_test.cc | |||
| pad_end_op_test.cc | |||
| path_test.cc | |||
| project_op_test.cc | |||
| queue_test.cc | |||
| random_crop_op_test.cc | |||
| random_crop_decode_resize_op_test.cc | |||
| random_crop_and_resize_op_test.cc | |||
| random_color_adjust_op_test.cc | |||
| random_horizontal_flip_op_test.cc | |||
| random_resize_op_test.cc | |||
| random_rotation_op_test.cc | |||
| random_vertical_flip_op_test.cc | |||
| rename_op_test.cc | |||
| repeat_op_test.cc | |||
| skip_op_test.cc | |||
| rescale_op_test.cc | |||
| resize_bilinear_op_test.cc | |||
| resize_op_test.cc | |||
| shuffle_op_test.cc | |||
| stand_alone_samplers_test.cc | |||
| status_test.cc | |||
| storage_op_test.cc | |||
| task_manager_test.cc | |||
| tensor_test.cc | |||
| tensor_string_test.cc | |||
| tensorshape_test.cc | |||
| tfReader_op_test.cc | |||
| to_float16_op_test.cc | |||
| type_cast_op_test.cc | |||
| zip_op_test.cc | |||
| random_resize_op_test.cc | |||
| subset_random_sampler_test.cc | |||
| weighted_random_sampler_test.cc | |||
| mnist_op_test.cc | |||
| manifest_op_test.cc | |||
| voc_op_test.cc | |||
| cifar_op_test.cc | |||
| celeba_op_test.cc | |||
| take_op_test.cc | |||
| clue_op_test.cc | |||
| text_file_op_test.cc | |||
| filter_op_test.cc | |||
| concat_op_test.cc | |||
| jieba_tokenizer_op_test.cc | |||
| tokenizer_op_test.cc | |||
| gnn_graph_test.cc | |||
| coco_op_test.cc | |||
| fill_op_test.cc | |||
| mask_test.cc | |||
| trucate_pair_test.cc | |||
| concatenate_op_test.cc | |||
| ) | |||
| add_executable(de_ut_tests ${DE_UT_SRCS}) | |||
| @@ -88,8 +89,8 @@ target_link_libraries(de_ut_tests PRIVATE _c_dataengine pybind11::embed ${GTEST_ | |||
| gtest_discover_tests(de_ut_tests WORKING_DIRECTORY ${Project_DIR}/tests/dataset) | |||
| install(TARGETS de_ut_tests | |||
| RUNTIME DESTINATION test) | |||
| RUNTIME DESTINATION test) | |||
| # For internal testing only. | |||
| install(DIRECTORY ${Project_DIR}/tests/dataset/data/ | |||
| DESTINATION test/data) | |||
| DESTINATION test/data) | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "dataset/kernels/data/concatenate_op.h" | |||
| #include "utils/log_adapter.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::LogStream; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::MsLogLevel::INFO; | |||
| class MindDataTestConcatenateOp : public UT::Common { | |||
| protected: | |||
| MindDataTestConcatenateOp() {} | |||
| }; | |||
| TEST_F(MindDataTestConcatenateOp, TestOp) { | |||
| MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp."; | |||
| uint64_t labels[3] = {1, 1, 2}; | |||
| TensorShape shape({3}); | |||
| std::shared_ptr<Tensor> input = | |||
| std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels)); | |||
| uint64_t append_labels[3] = {4, 4, 4}; | |||
| std::shared_ptr<Tensor> append = | |||
| std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(append_labels)); | |||
| std::shared_ptr<Tensor> output; | |||
| std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append)); | |||
| TensorRow in; | |||
| in.push_back(input); | |||
| TensorRow out_row; | |||
| Status s = op->Compute(in, &out_row); | |||
| uint64_t out[6] = {1, 1, 2, 4, 4, 4}; | |||
| std::shared_ptr<Tensor> expected = | |||
| std::make_shared<Tensor>(TensorShape{6}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out)); | |||
| output = out_row[0]; | |||
| EXPECT_TRUE(s.IsOk()); | |||
| ASSERT_TRUE(output->shape() == expected->shape()); | |||
| ASSERT_TRUE(output->type() == expected->type()); | |||
| MS_LOG(DEBUG) << *output << std::endl; | |||
| MS_LOG(DEBUG) << *expected << std::endl; | |||
| ASSERT_TRUE(*output == *expected); | |||
| // std::vector<TensorShape> inputs = {TensorShape({3})}; | |||
| // std::vector<TensorShape> outputs = {}; | |||
| // s = op->OutputShape(inputs, outputs); | |||
| // EXPECT_TRUE(s.IsOk()); | |||
| // ASSERT_TRUE(outputs[0] == TensorShape{6}); | |||
| // MS_LOG(INFO) << "MindDataTestConcatenateOp-TestOp end."; | |||
| } | |||
| @@ -141,7 +141,6 @@ TEST_F(MindDataTestTensorDE, InsertTensor) { | |||
| std::shared_ptr<Tensor> t4; | |||
| Tensor::CreateTensor(&t4, z, TensorShape({2, 3})); | |||
| ASSERT_EQ(*t == *t4, true); | |||
| std::shared_ptr<Tensor> t5; | |||
| @@ -407,3 +406,30 @@ TEST_F(MindDataTestTensorDE, TensorSlice) { | |||
| t->Slice(&t2, std::vector<dsize_t>{0, 1, 2, 3, 4}); | |||
| ASSERT_EQ(*t2, *t); | |||
| } | |||
| TEST_F(MindDataTestTensorDE, TensorConcatenate) { | |||
| std::vector<uint32_t> values1 = {1, 2, 3, 0, 0, 0}; | |||
| std::vector<uint32_t> values2 = {4, 5, 6}; | |||
| std::vector<uint32_t> expected = {1, 2, 3, 4, 5, 6}; | |||
| std::shared_ptr<Tensor> t1; | |||
| Tensor::CreateTensor(&t1, values1); | |||
| std::shared_ptr<Tensor> t2; | |||
| Tensor::CreateTensor(&t2, values2); | |||
| std::shared_ptr<Tensor> out; | |||
| Tensor::CreateTensor(&out, expected); | |||
| Status s = t1->Concatenate({3}, t2); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| auto i = out->begin<uint32_t>(); | |||
| auto j = t1->begin<uint32_t>(); | |||
| for (; i != out->end<uint32_t>(); i++, j++) { | |||
| ASSERT_TRUE(*i == *j); | |||
| } | |||
| // should fail if the concatenated vector is too large | |||
| s = t1->Concatenate({5}, t2); | |||
| EXPECT_FALSE(s.IsOk()); | |||
| } | |||
| @@ -0,0 +1,175 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing concatenate op | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.c_transforms as data_trans | |||
| def test_concatenate_op_all(): | |||
| def gen(): | |||
| yield (np.array([5., 6., 7., 8.], dtype=np.float),) | |||
| prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float) | |||
| append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor) | |||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||
| expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3, | |||
| 11., 12.]) | |||
| for data_row in data: | |||
| np.testing.assert_array_equal(data_row[0], expected) | |||
| def test_concatenate_op_none(): | |||
| def gen(): | |||
| yield (np.array([5., 6., 7., 8.], dtype=np.float),) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| concatenate_op = data_trans.Concatenate() | |||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||
| for data_row in data: | |||
| np.testing.assert_array_equal(data_row[0], np.array([5., 6., 7., 8.], dtype=np.float)) | |||
| def test_concatenate_op_string(): | |||
| def gen(): | |||
| yield (np.array(["ss", "ad"], dtype='S'),) | |||
| prepend_tensor = np.array(["dw", "df"], dtype='S') | |||
| append_tensor = np.array(["dwsdf", "df"], dtype='S') | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor) | |||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||
| expected = np.array(["dw", "df", "ss", "ad", "dwsdf", "df"], dtype='S') | |||
| for data_row in data: | |||
| np.testing.assert_array_equal(data_row[0], expected) | |||
| def test_concatenate_op_multi_input_string(): | |||
| prepend_tensor = np.array(["dw", "df"], dtype='S') | |||
| append_tensor = np.array(["dwsdf", "df"], dtype='S') | |||
| data = ([["1", "2", "d"]], [["3", "4", "e"]]) | |||
| data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"]) | |||
| concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor, append=append_tensor) | |||
| data = data.map(input_columns=["col1", "col2"], columns_order=["out1"], output_columns=["out1"], | |||
| operations=concatenate_op) | |||
| expected = np.array(["dw", "df", "1", "2", "d", "3", "4", "e", "dwsdf", "df"], dtype='S') | |||
| for data_row in data: | |||
| np.testing.assert_array_equal(data_row[0], expected) | |||
| def test_concatenate_op_multi_input_numeric(): | |||
| prepend_tensor = np.array([3, 5]) | |||
| data = ([[1, 2]], [[3, 4]]) | |||
| data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"]) | |||
| concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor) | |||
| data = data.map(input_columns=["col1", "col2"], columns_order=["out1"], output_columns=["out1"], | |||
| operations=concatenate_op) | |||
| expected = np.array([3, 5, 1, 2, 3, 4]) | |||
| for data_row in data: | |||
| np.testing.assert_array_equal(data_row[0], expected) | |||
| def test_concatenate_op_type_mismatch(): | |||
| def gen(): | |||
| yield (np.array([3, 4], dtype=np.float),) | |||
| prepend_tensor = np.array(["ss", "ad"], dtype='S') | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor) | |||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||
| with pytest.raises(RuntimeError) as error_info: | |||
| for _ in data: | |||
| pass | |||
| assert "Tensor types do not match" in repr(error_info.value) | |||
| def test_concatenate_op_type_mismatch2(): | |||
| def gen(): | |||
| yield (np.array(["ss", "ad"], dtype='S'),) | |||
| prepend_tensor = np.array([3, 5], dtype=np.float) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor) | |||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||
| with pytest.raises(RuntimeError) as error_info: | |||
| for _ in data: | |||
| pass | |||
| assert "Tensor types do not match" in repr(error_info.value) | |||
| def test_concatenate_op_incorrect_dim(): | |||
| def gen(): | |||
| yield (np.array([["ss", "ad"], ["ss", "ad"]], dtype='S'),) | |||
| prepend_tensor = np.array([3, 5], dtype=np.float) | |||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||
| with pytest.raises(RuntimeError) as error_info: | |||
| for _ in data: | |||
| pass | |||
| assert "Only 1D tensors supported" in repr(error_info.value) | |||
| def test_concatenate_op_wrong_axis(): | |||
| with pytest.raises(ValueError) as error_info: | |||
| data_trans.Concatenate(2) | |||
| assert "only 1D concatenation supported." in repr(error_info.value) | |||
| def test_concatenate_op_incorrect_input_dim(): | |||
| def gen(): | |||
| yield (np.array(["ss", "ad"], dtype='S'),) | |||
| prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S') | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor) | |||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||
| with pytest.raises(RuntimeError) as error_info: | |||
| for _ in data: | |||
| pass | |||
| assert "Only 1D tensors supported" in repr(error_info.value) | |||
| if __name__ == "__main__": | |||
| test_concatenate_op_all() | |||
| test_concatenate_op_none() | |||
| test_concatenate_op_string() | |||
| test_concatenate_op_type_mismatch() | |||
| test_concatenate_op_type_mismatch2() | |||
| test_concatenate_op_incorrect_dim() | |||
| test_concatenate_op_incorrect_input_dim() | |||
| test_concatenate_op_multi_input_numeric() | |||
| test_concatenate_op_multi_input_string() | |||
| test_concatenate_op_wrong_axis() | |||