You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

concatenate_op.cc 2.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/kernels/data/concatenate_op.h"
  17. #include "dataset/core/tensor.h"
  18. #include "dataset/kernels/data/data_utils.h"
  19. #include "dataset/kernels/tensor_op.h"
  20. namespace mindspore {
  21. namespace dataset {
  22. Status ConcatenateOp::Compute(const TensorRow &input, TensorRow *output) {
  23. IO_CHECK_VECTOR(input, output);
  24. RETURN_IF_NOT_OK(Concatenate(input, output, axis_, prepend_, append_));
  25. return Status::OK();
  26. }
  27. Status ConcatenateOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
  28. RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
  29. std::vector<TensorShape> inputs_copy;
  30. inputs_copy.push_back(inputs[0].Squeeze());
  31. CHECK_FAIL_RETURN_UNEXPECTED(inputs.at(0).Rank() == 1, "Only 1D input tensors supported");
  32. outputs.clear();
  33. dsize_t output_shape = 0;
  34. output_shape = output_shape + inputs.at(0).NumOfElements();
  35. if (prepend_ != nullptr) {
  36. CHECK_FAIL_RETURN_UNEXPECTED(prepend_->shape().Rank() == 1, "Only 1D prepend tensors supported");
  37. output_shape = output_shape + prepend_->shape().NumOfElements();
  38. }
  39. if (append_ != nullptr) {
  40. CHECK_FAIL_RETURN_UNEXPECTED(append_->shape().Rank() == 1, "Only 1D append tensors supported");
  41. output_shape = output_shape + append_->shape().NumOfElements();
  42. }
  43. outputs.emplace_back(std::vector<dsize_t>{output_shape});
  44. return Status::OK();
  45. }
  46. } // namespace dataset
  47. } // namespace mindspore