You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

batch_op_test.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <memory>
  18. #include <string>
  19. #include "minddata/dataset/core/client.h"
  20. #include "common/common.h"
  21. #include "utils/ms_utils.h"
  22. #include "gtest/gtest.h"
  23. #include "minddata/dataset/core/global_context.h"
  24. #include "utils/log_adapter.h"
  25. #include "securec.h"
  26. #include "minddata/dataset/util/status.h"
  27. namespace common = mindspore::common;
  28. namespace de = mindspore::dataset;
  29. using namespace mindspore::dataset;
  30. using mindspore::LogStream;
  31. using mindspore::ExceptionType::NoExceptionType;
  32. using mindspore::MsLogLevel::ERROR;
  33. class MindDataTestBatchOp : public UT::DatasetOpTesting {
  34. protected:
  35. };
  36. std::shared_ptr<de::BatchOp> Batch(int32_t batch_size = 1, bool drop = false, int rows_per_buf = 2) {
  37. Status rc;
  38. std::shared_ptr<de::BatchOp> op;
  39. rc = de::BatchOp::Builder(batch_size).SetDrop(drop).Build(&op);
  40. EXPECT_TRUE(rc.IsOk());
  41. return op;
  42. }
  43. std::shared_ptr<de::RepeatOp> Repeat(int repeat_cnt = 1) {
  44. de::RepeatOp::Builder builder(repeat_cnt);
  45. std::shared_ptr<de::RepeatOp> op;
  46. Status rc = builder.Build(&op);
  47. EXPECT_TRUE(rc.IsOk());
  48. return op;
  49. }
  50. std::shared_ptr<de::TFReaderOp> TFReader(std::string schema, int rows_per_buf = 2, int num_works = 8) {
  51. std::shared_ptr<de::TFReaderOp> so;
  52. de::TFReaderOp::Builder builder;
  53. builder.SetDatasetFilesList({schema}).SetRowsPerBuffer(rows_per_buf).SetNumWorkers(num_works);
  54. Status rc = builder.Build(&so);
  55. return so;
  56. }
  57. std::shared_ptr<de::ExecutionTree> Build(std::vector<std::shared_ptr<de::DatasetOp>> ops) {
  58. std::shared_ptr<de::ExecutionTree> tree = std::make_shared<de::ExecutionTree>();
  59. for (int i = 0; i < ops.size(); i++) {
  60. tree->AssociateNode(ops[i]);
  61. if (i > 0) {
  62. ops[i]->AddChild(ops[i - 1]);
  63. }
  64. if (i == ops.size() - 1) {
  65. tree->AssignRoot(ops[i]);
  66. }
  67. }
  68. return tree;
  69. }
  70. TEST_F(MindDataTestBatchOp, TestSimpleBatch) {
  71. std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
  72. bool success = false;
  73. const std::shared_ptr<de::BatchOp> &op = Batch(12);
  74. EXPECT_EQ(op->Name(), "BatchOp");
  75. auto tree = Build({TFReader(schema_file), op});
  76. tree->Prepare();
  77. Status rc = tree->Launch();
  78. if (rc.IsError()) {
  79. MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
  80. } else {
  81. int64_t payload[] = {-9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807};
  82. de::DatasetIterator di(tree);
  83. TensorMap tensor_map;
  84. rc = di.GetNextAsMap(&tensor_map);
  85. EXPECT_TRUE(rc.IsOk());
  86. std::shared_ptr<de::Tensor> t;
  87. rc = de::Tensor::CreateFromMemory(de::TensorShape({12, 1}), de::DataType(DataType::DE_INT64),
  88. (unsigned char *)payload, &t);
  89. EXPECT_TRUE(rc.IsOk());
  90. // verify the actual data in Tensor is correct
  91. EXPECT_EQ(*t == *tensor_map["col_sint64"], true);
  92. // change what's in Tensor and verify this time the data is incorrect1;
  93. EXPECT_EQ(*t == *tensor_map["col_sint16"], false);
  94. rc = di.GetNextAsMap(&tensor_map);
  95. EXPECT_TRUE(rc.IsOk());
  96. if (tensor_map.size() == 0) {
  97. success = true;
  98. }
  99. }
  100. EXPECT_EQ(success, true);
  101. }
  102. TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) {
  103. std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
  104. bool success = false;
  105. auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, true, 99)});
  106. tree->Prepare();
  107. Status rc = tree->Launch();
  108. if (rc.IsError()) {
  109. MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
  110. } else {
  111. int64_t payload[] = {-9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807,
  112. -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807};
  113. de::DatasetIterator di(tree);
  114. std::shared_ptr<de::Tensor> t1, t2, t3;
  115. rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
  116. (unsigned char *)payload, &t1);
  117. EXPECT_TRUE(rc.IsOk());
  118. rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
  119. (unsigned char *)(payload + 7), &t2);
  120. EXPECT_TRUE(rc.IsOk());
  121. rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
  122. (unsigned char *)(payload + 2), &t3);
  123. EXPECT_TRUE(rc.IsOk());
  124. TensorMap tensor_map;
  125. rc = di.GetNextAsMap(&tensor_map);
  126. EXPECT_TRUE(rc.IsOk());
  127. EXPECT_EQ(*t1 == *(tensor_map["col_sint64"]), true); // first call to getNext()
  128. rc = di.GetNextAsMap(&tensor_map);
  129. EXPECT_TRUE(rc.IsOk());
  130. EXPECT_EQ(*t2 == *(tensor_map["col_sint64"]), true); // second call to getNext()
  131. rc = di.GetNextAsMap(&tensor_map);
  132. EXPECT_TRUE(rc.IsOk());
  133. EXPECT_EQ(*t3 == *(tensor_map["col_sint64"]), true); // third call to getNext()
  134. rc = di.GetNextAsMap(&tensor_map);
  135. EXPECT_TRUE(rc.IsOk());
  136. if (tensor_map.size() == 0) {
  137. success = true;
  138. }
  139. }
  140. EXPECT_EQ(success, true);
  141. }
  142. TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) {
  143. std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
  144. bool success = false;
  145. auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, false, 99)});
  146. tree->Prepare();
  147. Status rc = tree->Launch();
  148. if (rc.IsError()) {
  149. MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
  150. } else {
  151. int64_t payload[] = {-9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807,
  152. -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807};
  153. de::DatasetIterator di(tree);
  154. std::shared_ptr<de::Tensor> t1, t2, t3, t4;
  155. rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
  156. (unsigned char *)payload, &t1);
  157. EXPECT_TRUE(rc.IsOk());
  158. rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
  159. (unsigned char *)(payload + 7), &t2);
  160. EXPECT_TRUE(rc.IsOk());
  161. rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
  162. (unsigned char *)(payload + 2), &t3);
  163. EXPECT_TRUE(rc.IsOk());
  164. rc = de::Tensor::CreateFromMemory(de::TensorShape({3, 1}), de::DataType(DataType::DE_INT64),
  165. (unsigned char *)(payload + 9), &t4);
  166. EXPECT_TRUE(rc.IsOk());
  167. TensorMap tensor_map;
  168. rc = di.GetNextAsMap(&tensor_map);
  169. EXPECT_TRUE(rc.IsOk());
  170. EXPECT_EQ(*t1 == *(tensor_map["col_sint64"]), true); // first call to getNext()
  171. rc = di.GetNextAsMap(&tensor_map);
  172. EXPECT_TRUE(rc.IsOk());
  173. EXPECT_EQ(*t2 == *(tensor_map["col_sint64"]), true); // second call to getNext()
  174. rc = di.GetNextAsMap(&tensor_map);
  175. EXPECT_TRUE(rc.IsOk());
  176. EXPECT_EQ(*t3 == *(tensor_map["col_sint64"]), true); // third call to getNext()
  177. rc = di.GetNextAsMap(&tensor_map);
  178. EXPECT_TRUE(rc.IsOk());
  179. EXPECT_EQ(*t4 == *(tensor_map["col_sint64"]), true); // last call to getNext()
  180. rc = di.GetNextAsMap(&tensor_map);
  181. EXPECT_TRUE(rc.IsOk());
  182. if (tensor_map.size() == 0) {
  183. success = true;
  184. }
  185. }
  186. EXPECT_EQ(success, true);
  187. }
  188. TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) {
  189. std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
  190. bool success = false;
  191. auto tree = Build({TFReader(schema_file), Batch(7, false, 99), Repeat(2)});
  192. tree->Prepare();
  193. Status rc = tree->Launch();
  194. if (rc.IsError()) {
  195. MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
  196. } else {
  197. int64_t payload[] = {-9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807,
  198. -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807};
  199. de::DatasetIterator di(tree);
  200. std::shared_ptr<de::Tensor> t1, t2;
  201. rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
  202. (unsigned char *)payload, &t1);
  203. EXPECT_TRUE(rc.IsOk());
  204. rc = de::Tensor::CreateFromMemory(de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64),
  205. (unsigned char *)(payload + 7), &t2);
  206. EXPECT_TRUE(rc.IsOk());
  207. TensorMap tensor_map;
  208. rc = di.GetNextAsMap(&tensor_map);
  209. EXPECT_TRUE(rc.IsOk());
  210. EXPECT_EQ(*t1 == *(tensor_map["col_sint64"]), true); // first call to getNext()
  211. rc = di.GetNextAsMap(&tensor_map);
  212. EXPECT_TRUE(rc.IsOk());
  213. EXPECT_EQ(*t2 == *(tensor_map["col_sint64"]), true); // second call to getNext()
  214. rc = di.GetNextAsMap(&tensor_map);
  215. EXPECT_TRUE(rc.IsOk());
  216. EXPECT_EQ(*t1 == *(tensor_map["col_sint64"]), true); // third call to getNext()
  217. rc = di.GetNextAsMap(&tensor_map);
  218. EXPECT_TRUE(rc.IsOk());
  219. EXPECT_EQ(*t2 == *(tensor_map["col_sint64"]), true); // last call to getNext()
  220. rc = di.GetNextAsMap(&tensor_map);
  221. EXPECT_TRUE(rc.IsOk());
  222. if (tensor_map.size() == 0) {
  223. success = true;
  224. }
  225. }
  226. EXPECT_EQ(success, true);
  227. }
  228. TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) {
  229. std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
  230. bool success = false;
  231. auto tree = Build({TFReader(schema_file), Batch(5, true, 99), Repeat(2)});
  232. tree->Prepare();
  233. Status rc = tree->Launch();
  234. if (rc.IsError()) {
  235. MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
  236. } else {
  237. int64_t payload[] = {-9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807,
  238. -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807};
  239. de::DatasetIterator di(tree);
  240. std::shared_ptr<de::Tensor> t1, t2;
  241. rc = de::Tensor::CreateFromMemory(de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64),
  242. (unsigned char *)payload, &t1);
  243. EXPECT_TRUE(rc.IsOk());
  244. rc = de::Tensor::CreateFromMemory(de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64),
  245. (unsigned char *)(payload + 5), &t2);
  246. EXPECT_TRUE(rc.IsOk());
  247. TensorMap tensor_map;
  248. rc = di.GetNextAsMap(&tensor_map);
  249. EXPECT_TRUE(rc.IsOk());
  250. EXPECT_EQ(*t1 == *(tensor_map["col_sint64"]), true); // first call to getNext()
  251. rc = di.GetNextAsMap(&tensor_map);
  252. EXPECT_TRUE(rc.IsOk());
  253. EXPECT_EQ(*t2 == *(tensor_map["col_sint64"]), true); // second call to getNext()
  254. rc = di.GetNextAsMap(&tensor_map);
  255. EXPECT_TRUE(rc.IsOk());
  256. EXPECT_EQ(*t1 == *(tensor_map["col_sint64"]), true); // third call to getNext()
  257. rc = di.GetNextAsMap(&tensor_map);
  258. EXPECT_TRUE(rc.IsOk());
  259. EXPECT_EQ(*t2 == *(tensor_map["col_sint64"]), true); // last call to getNext()
  260. rc = di.GetNextAsMap(&tensor_map);
  261. EXPECT_TRUE(rc.IsOk());
  262. if (tensor_map.size() == 0) {
  263. success = true;
  264. }
  265. }
  266. EXPECT_EQ(success, true);
  267. }
  268. TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) {
  269. std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
  270. std::shared_ptr<BatchOp> op;
  271. PadInfo m;
  272. std::shared_ptr<Tensor> pad_value;
  273. Tensor::CreateEmpty(TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32), &pad_value);
  274. pad_value->SetItemAt<float>({}, -1);
  275. m.insert({"col_1d", std::make_pair(TensorShape({4}), pad_value)});
  276. de::BatchOp::Builder(12).SetDrop(false).SetPaddingMap(m, true).Build(&op);
  277. auto tree = Build({TFReader(schema_file), op});
  278. tree->Prepare();
  279. Status rc = tree->Launch();
  280. if (rc.IsError()) {
  281. MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
  282. } else {
  283. int64_t payload[] = {-9223372036854775807 - 1,
  284. 1,
  285. -1,
  286. -1,
  287. 2,
  288. 3,
  289. -1,
  290. -1,
  291. 4,
  292. 5,
  293. -1,
  294. -1,
  295. 6,
  296. 7,
  297. -1,
  298. -1,
  299. 8,
  300. 9,
  301. -1,
  302. -1,
  303. 10,
  304. 11,
  305. -1,
  306. -1,
  307. 12,
  308. 13,
  309. -1,
  310. -1,
  311. 14,
  312. 15,
  313. -1,
  314. -1,
  315. 16,
  316. 17,
  317. -1,
  318. -1,
  319. 18,
  320. 19,
  321. -1,
  322. -1,
  323. 20,
  324. 21,
  325. -1,
  326. -1,
  327. 22,
  328. 23,
  329. -1,
  330. -1};
  331. std::shared_ptr<de::Tensor> t;
  332. rc = de::Tensor::CreateFromMemory(de::TensorShape({12, 4}), de::DataType(DataType::DE_INT64),
  333. (unsigned char *)payload, &t);
  334. de::DatasetIterator di(tree);
  335. TensorMap tensor_map;
  336. rc = di.GetNextAsMap(&tensor_map);
  337. EXPECT_TRUE((*t) == (*(tensor_map["col_1d"])));
  338. rc = di.GetNextAsMap(&tensor_map);
  339. EXPECT_TRUE(tensor_map.size() == 0);
  340. EXPECT_TRUE(rc.IsOk());
  341. }
  342. }