You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

concatenate_op_test.cc 4.2 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common.h"
  17. #include "minddata/dataset/kernels/data/concatenate_op.h"
  18. #include "utils/log_adapter.h"
  19. using namespace mindspore::dataset;
  20. using mindspore::LogStream;
  21. using mindspore::ExceptionType::NoExceptionType;
  22. using mindspore::MsLogLevel::INFO;
  23. class MindDataTestConcatenateOp : public UT::Common {
  24. protected:
  25. MindDataTestConcatenateOp() {}
  26. };
  27. TEST_F(MindDataTestConcatenateOp, TestOp) {
  28. MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp-SingleRowinput.";
  29. std::vector<uint64_t> labels = {1, 1, 2};
  30. std::shared_ptr<Tensor> input;
  31. Tensor::CreateFromVector(labels, &input);
  32. std::vector<uint64_t> append_labels = {4, 4, 4};
  33. std::shared_ptr<Tensor> append;
  34. Tensor::CreateFromVector(append_labels, &append);
  35. std::shared_ptr<Tensor> output;
  36. std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append));
  37. TensorRow in;
  38. in.push_back(input);
  39. TensorRow out_row;
  40. Status s = op->Compute(in, &out_row);
  41. std::vector<uint64_t> out = {1, 1, 2, 4, 4, 4};
  42. std::shared_ptr<Tensor> expected;
  43. Tensor::CreateFromVector(out, &expected);
  44. output = out_row[0];
  45. EXPECT_TRUE(s.IsOk());
  46. ASSERT_TRUE(output->shape() == expected->shape());
  47. ASSERT_TRUE(output->type() == expected->type());
  48. MS_LOG(DEBUG) << *output << std::endl;
  49. MS_LOG(DEBUG) << *expected << std::endl;
  50. ASSERT_TRUE(*output == *expected);
  51. }
  52. TEST_F(MindDataTestConcatenateOp, TestOp2) {
  53. MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp2-MultiInput.";
  54. std::vector<uint64_t> labels = {1, 12, 2};
  55. std::shared_ptr<Tensor> row_1;
  56. Tensor::CreateFromVector(labels, &row_1);
  57. std::shared_ptr<Tensor> row_2;
  58. Tensor::CreateFromVector(labels, &row_2);
  59. std::vector<uint64_t> append_labels = {4, 4, 4};
  60. std::shared_ptr<Tensor> append;
  61. Tensor::CreateFromVector(append_labels, &append);
  62. TensorRow tensor_list;
  63. tensor_list.push_back(row_1);
  64. tensor_list.push_back(row_2);
  65. std::shared_ptr<Tensor> output;
  66. std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append));
  67. TensorRow out_row;
  68. Status s = op->Compute(tensor_list, &out_row);
  69. std::vector<uint64_t> out = {1, 12, 2, 1, 12, 2, 4, 4, 4};
  70. std::shared_ptr<Tensor> expected;
  71. Tensor::CreateFromVector(out, &expected);
  72. output = out_row[0];
  73. EXPECT_TRUE(s.IsOk());
  74. ASSERT_TRUE(output->shape() == expected->shape());
  75. ASSERT_TRUE(output->type() == expected->type());
  76. MS_LOG(DEBUG) << *output << std::endl;
  77. MS_LOG(DEBUG) << *expected << std::endl;
  78. ASSERT_TRUE(*output == *expected);
  79. }
  80. TEST_F(MindDataTestConcatenateOp, TestOp3) {
  81. MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp3-Strings.";
  82. std::vector<std::string> labels = {"hello", "bye"};
  83. std::shared_ptr<Tensor> row_1;
  84. Tensor::CreateFromVector(labels, &row_1);
  85. std::vector<std::string> append_labels = {"1", "2", "3"};
  86. std::shared_ptr<Tensor> append;
  87. Tensor::CreateFromVector(append_labels, &append);
  88. TensorRow tensor_list;
  89. tensor_list.push_back(row_1);
  90. std::shared_ptr<Tensor> output;
  91. std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append));
  92. TensorRow out_row;
  93. Status s = op->Compute(tensor_list, &out_row);
  94. std::vector<std::string> out = {"hello", "bye", "1", "2", "3"};
  95. std::shared_ptr<Tensor> expected;
  96. Tensor::CreateFromVector(out, &expected);
  97. output = out_row[0];
  98. EXPECT_TRUE(s.IsOk());
  99. ASSERT_TRUE(output->shape() == expected->shape());
  100. ASSERT_TRUE(output->type() == expected->type());
  101. MS_LOG(DEBUG) << *output << std::endl;
  102. MS_LOG(DEBUG) << *expected << std::endl;
  103. ASSERT_TRUE(*output == *expected);
  104. }