You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

batch_op_test.cc 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <memory>
  17. #include <string>
  18. #include "minddata/dataset/core/client.h"
  19. // #include "minddata/dataset/core/pybind_support.h"
  20. // #include "minddata/dataset/core/tensor.h"
  21. // #include "minddata/dataset/core/tensor_shape.h"
  22. // #include "minddata/dataset/engine/datasetops/batch_op.h"
  23. #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
  24. #include "common/common.h"
  25. #include "gtest/gtest.h"
  26. #include "utils/log_adapter.h"
  27. #include "securec.h"
  28. #include "minddata/dataset/util/status.h"
  29. // #include "pybind11/numpy.h"
  30. // #include "pybind11/pybind11.h"
  31. // #include "utils/ms_utils.h"
  32. // #include "minddata/dataset/engine/db_connector.h"
  33. // #include "minddata/dataset/kernels/data/data_utils.h"
  34. namespace common = mindspore::common;
  35. namespace de = mindspore::dataset;
  36. using namespace mindspore::dataset;
  37. using mindspore::LogStream;
  38. using mindspore::ExceptionType::NoExceptionType;
  39. using mindspore::MsLogLevel::ERROR;
  40. class MindDataTestBatchOp : public UT::DatasetOpTesting {
  41. protected:
  42. };
  43. // This test has been disabled because PadInfo is not currently supported in the C++ API.
  44. // Feature: Test Batch op with padding on TFReader
  45. // Description: Create Batch operation with padding on a TFReader dataset
  46. // Expectation: The data within the created object should match the expected data
  47. TEST_F(MindDataTestBatchOp, DISABLED_TestSimpleBatchPadding) {
  48. std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
  49. PadInfo m;
  50. std::shared_ptr<Tensor> pad_value;
  51. Tensor::CreateEmpty(TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32), &pad_value);
  52. pad_value->SetItemAt<float>({}, -1);
  53. m.insert({"col_1d", std::make_pair(TensorShape({4}), pad_value)});
  54. /*
  55. std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
  56. auto op_connector_size = config_manager->op_connector_size();
  57. auto num_workers = config_manager->num_parallel_workers();
  58. std::vector<std::string> input_columns = {};
  59. std::vector<std::string> output_columns = {};
  60. pybind11::function batch_size_func;
  61. pybind11::function batch_map_func;
  62. */
  63. int32_t batch_size = 12;
  64. bool drop = false;
  65. std::shared_ptr<BatchOp> op = Batch(batch_size, drop, m);
  66. // std::make_shared<BatchOp>(batch_size, drop, pad, op_connector_size, num_workers, input_columns, output_columns,
  67. // batch_size_func, batch_map_func, m);
  68. auto tree = Build({TFReader(schema_file), op});
  69. tree->Prepare();
  70. Status rc = tree->Launch();
  71. if (rc.IsError()) {
  72. MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
  73. } else {
  74. int64_t payload[] = {-9223372036854775807 - 1,
  75. 1,
  76. -1,
  77. -1,
  78. 2,
  79. 3,
  80. -1,
  81. -1,
  82. 4,
  83. 5,
  84. -1,
  85. -1,
  86. 6,
  87. 7,
  88. -1,
  89. -1,
  90. 8,
  91. 9,
  92. -1,
  93. -1,
  94. 10,
  95. 11,
  96. -1,
  97. -1,
  98. 12,
  99. 13,
  100. -1,
  101. -1,
  102. 14,
  103. 15,
  104. -1,
  105. -1,
  106. 16,
  107. 17,
  108. -1,
  109. -1,
  110. 18,
  111. 19,
  112. -1,
  113. -1,
  114. 20,
  115. 21,
  116. -1,
  117. -1,
  118. 22,
  119. 23,
  120. -1,
  121. -1};
  122. std::shared_ptr<de::Tensor> t;
  123. rc = de::Tensor::CreateFromMemory(de::TensorShape({12, 4}), de::DataType(DataType::DE_INT64),
  124. (unsigned char *)payload, &t);
  125. de::DatasetIterator di(tree);
  126. TensorMap tensor_map;
  127. rc = di.GetNextAsMap(&tensor_map);
  128. EXPECT_TRUE((*t) == (*(tensor_map["col_1d"])));
  129. rc = di.GetNextAsMap(&tensor_map);
  130. EXPECT_TRUE(tensor_map.size() == 0);
  131. EXPECT_TRUE(rc.IsOk());
  132. }
  133. }