You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

cache_op_test.cc 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <string>

#include "common/common.h"
#include "gtest/gtest.h"
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/storage_container.h"  // lint !e322
#include "utils/log_adapter.h"

using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::dataset::CacheClient;
using mindspore::dataset::TaskGroup;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
  36. class MindDataTestCacheOp : public UT::DatasetOpTesting {
  37. public:
  38. void SetUp() override {
  39. DatasetOpTesting::SetUp();
  40. GlobalInit();
  41. }
  42. };
  43. TEST_F(MindDataTestCacheOp, TestCacheServer) {
  44. Status rc;
  45. CacheClient myClient(1, 0, true); // use arbitrary session of 1, size of 0, spilling is true
  46. // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
  47. rc = myClient.CreateCache(1, true);
  48. EXPECT_TRUE(rc.IsOk());
  49. std::cout << myClient << std::endl;
  50. // Create a schema using the C api's
  51. int32_t rank = 0; // not used
  52. std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  53. // 2 columns. First column is an "image" 640,480,3
  54. TensorShape c1Shape({640, 480, 3});
  55. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  56. rank, // not used
  57. &c1Shape);
  58. // Column 2 will just be a scalar label number
  59. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  60. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  61. testSchema->AddColumn(c1);
  62. testSchema->AddColumn(c2);
  63. std::unordered_map<std::string, int32_t> map;
  64. rc = testSchema->GetColumnNameMap(&map);
  65. EXPECT_TRUE(rc.IsOk());
  66. // Test the CacheSchema api
  67. rc = myClient.CacheSchema(map);
  68. EXPECT_TRUE(rc.IsOk());
  69. // Create a tensor, take a snapshot and restore it back, and compare.
  70. std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64));
  71. t->SetItemAt<uint64_t>({0, 0}, 1);
  72. t->SetItemAt<uint64_t>({0, 1}, 2);
  73. t->SetItemAt<uint64_t>({0, 2}, 3);
  74. t->SetItemAt<uint64_t>({1, 0}, 4);
  75. t->SetItemAt<uint64_t>({1, 1}, 5);
  76. t->SetItemAt<uint64_t>({1, 2}, 6);
  77. std::cout << *t << std::endl;
  78. TensorTable tbl;
  79. TensorRow row;
  80. row.push_back(t);
  81. int64_t row_id;
  82. rc = myClient.WriteRow(row, &row_id);
  83. EXPECT_TRUE(rc.IsOk());
  84. // Switch off build phase.
  85. rc = myClient.BuildPhaseDone();
  86. EXPECT_TRUE(rc.IsOk());
  87. // Now restore from cache.
  88. row.clear();
  89. rc = myClient.GetRows({row_id}, &tbl);
  90. row = tbl.front();
  91. EXPECT_TRUE(rc.IsOk());
  92. auto r = row.front();
  93. std::cout << *r << std::endl;
  94. // Compare
  95. bool cmp = (*t == *r);
  96. EXPECT_TRUE(cmp);
  97. // Get back the schema and verify
  98. std::unordered_map<std::string, int32_t> map_out;
  99. rc = myClient.FetchSchema(&map_out);
  100. EXPECT_TRUE(rc.IsOk());
  101. cmp = (map_out == map);
  102. EXPECT_TRUE(cmp);
  103. // Test Purge and Destroy
  104. rc = myClient.PurgeCache();
  105. EXPECT_TRUE(rc.IsOk());
  106. rc = myClient.DestroyCache();
  107. EXPECT_TRUE(rc.IsOk());
  108. }
  109. TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) {
  110. // Clear the rc of the master thread if any
  111. (void)TaskManager::GetMasterThreadRc();
  112. TaskGroup vg;
  113. Status rc;
  114. CacheClient myClient(1, 1, true); // use arbitrary session of 1, size 1, spilling is true
  115. // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
  116. rc = myClient.CreateCache(1, true);
  117. EXPECT_TRUE(rc.IsOk());
  118. std::cout << myClient << std::endl;
  119. std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64));
  120. t->SetItemAt<uint64_t>({0, 0}, 1);
  121. t->SetItemAt<uint64_t>({0, 1}, 2);
  122. t->SetItemAt<uint64_t>({0, 2}, 3);
  123. t->SetItemAt<uint64_t>({1, 0}, 4);
  124. t->SetItemAt<uint64_t>({1, 1}, 5);
  125. t->SetItemAt<uint64_t>({1, 2}, 6);
  126. TensorTable tbl;
  127. TensorRow row;
  128. row.push_back(t);
  129. // Cache tensor row t 5000 times using 10 threads.
  130. for (auto k = 0; k < 10; ++k) {
  131. Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status {
  132. TaskManager::FindMe()->Post();
  133. for (auto i = 0; i < 500; i++) {
  134. RETURN_IF_NOT_OK(myClient.WriteRow(row));
  135. }
  136. return Status::OK();
  137. });
  138. EXPECT_TRUE(vg_rc.IsOk());
  139. }
  140. ASSERT_TRUE(vg.join_all().IsOk());
  141. ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk());
  142. rc = myClient.BuildPhaseDone();
  143. ASSERT_TRUE(rc.IsOk());
  144. // Get statistics from the server.
  145. CacheClient::ServiceStat stat{};
  146. rc = myClient.GetStat(&stat);
  147. ASSERT_TRUE(rc.IsOk());
  148. std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached
  149. << "\n";
  150. // Expect there are 5000 rows there.
  151. EXPECT_EQ(5000, stat.max_row_id - stat.min_row_id + 1);
  152. // Get them all back using row id and compare with tensor t.
  153. for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) {
  154. tbl.clear();
  155. row.clear();
  156. rc = myClient.GetRows({i}, &tbl);
  157. EXPECT_TRUE(rc.IsOk());
  158. row = tbl.front();
  159. auto r = row.front();
  160. bool cmp = (*t == *r);
  161. EXPECT_TRUE(cmp);
  162. }
  163. rc = myClient.DestroyCache();
  164. EXPECT_TRUE(rc.IsOk());
  165. }
// Simple test with a repeated cache op over random data producer
//
//    RepeatOp
//        |
//     CacheOp
//        |
//  RandomDataOp
//
  174. TEST_F(MindDataTestCacheOp, TestRandomDataCache1) {
  175. Status rc;
  176. int32_t rank = 0; // not used
  177. MS_LOG(INFO) << "UT test TestRandomDataCache1";
  178. // Start with an empty execution tree
  179. auto myTree = std::make_shared<ExecutionTree>();
  180. // Create a schema using the C api's
  181. std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  182. // 2 columns. First column is an "image" 640,480,3
  183. TensorShape c1Shape({640, 480, 3});
  184. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  185. rank, // not used
  186. &c1Shape);
  187. // Column 2 will just be a scalar label number
  188. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  189. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  190. testSchema->AddColumn(c1);
  191. testSchema->AddColumn(c2);
  192. // RandomDataOp
  193. std::shared_ptr<RandomDataOp> myRandomDataOp;
  194. rc = RandomDataOp::Builder()
  195. .SetRowsPerBuffer(4)
  196. .SetNumWorkers(4)
  197. .SetDataSchema(std::move(testSchema))
  198. .SetTotalRows(50) // 50 samples for now
  199. .Build(&myRandomDataOp);
  200. EXPECT_TRUE(rc.IsOk());
  201. rc = myTree->AssociateNode(myRandomDataOp);
  202. EXPECT_TRUE(rc.IsOk());
  203. // CacheOp
  204. // size of 0, spilling is true
  205. std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true);
  206. std::shared_ptr<CacheOp> myCacheOp;
  207. int64_t num_samples = 0;
  208. int64_t start_index = 0;
  209. auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  210. rc = CacheOp::Builder()
  211. .SetNumWorkers(5)
  212. .SetClient(myClient)
  213. .SetRowsPerBuffer(4)
  214. .SetSampler(std::move(seq_sampler))
  215. .Build(&myCacheOp);
  216. EXPECT_TRUE(rc.IsOk());
  217. rc = myTree->AssociateNode(myCacheOp);
  218. EXPECT_TRUE(rc.IsOk());
  219. // RepeatOp
  220. uint32_t numRepeats = 4;
  221. std::shared_ptr<RepeatOp> myRepeatOp;
  222. rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  223. EXPECT_TRUE(rc.IsOk());
  224. rc = myTree->AssociateNode(myRepeatOp);
  225. EXPECT_TRUE(rc.IsOk());
  226. // Assign tree relations and root
  227. rc = myRepeatOp->AddChild(myCacheOp);
  228. EXPECT_TRUE(rc.IsOk());
  229. rc = myCacheOp->AddChild(myRandomDataOp);
  230. EXPECT_TRUE(rc.IsOk());
  231. rc = myTree->AssignRoot(myRepeatOp);
  232. EXPECT_TRUE(rc.IsOk());
  233. MS_LOG(INFO) << "Launching tree and begin iteration";
  234. rc = myTree->Prepare();
  235. EXPECT_TRUE(rc.IsOk());
  236. // quick check to see what tree looks like
  237. std::ostringstream ss;
  238. ss << *myTree; // some funny const error if I try to write directly to ms log stream
  239. MS_LOG(INFO) << "Here's the tree:\n" << ss.str();
  240. std::cout << *myClient << std::endl;
  241. rc = myTree->Launch();
  242. EXPECT_TRUE(rc.IsOk());
  243. // Start the loop of reading tensors from our pipeline
  244. DatasetIterator dI(myTree);
  245. TensorRow tensorList;
  246. rc = dI.FetchNextTensorRow(&tensorList);
  247. EXPECT_TRUE(rc.IsOk());
  248. int rowCount = 0;
  249. while (!tensorList.empty()) {
  250. // Don't display these rows, just count them
  251. MS_LOG(INFO) << "Row fetched #: " << rowCount;
  252. rc = dI.FetchNextTensorRow(&tensorList);
  253. EXPECT_TRUE(rc.IsOk());
  254. rowCount++;
  255. }
  256. ASSERT_EQ(rowCount, 200);
  257. rc = myClient->DestroyCache();
  258. EXPECT_TRUE(rc.IsOk());
  259. }
//// Simple test with a repeated cache op over random data producer.
//// This one will exceed memory and require a spill.
////
////    RepeatOp
////        |
////     CacheOp
////        |
////  RandomDataOp
////
  269. TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) {
  270. Status rc;
  271. int32_t rank = 0; // not used
  272. MS_LOG(INFO) << "UT test TestRandomDataCacheSpill";
  273. // Start with an empty execution tree
  274. auto myTree = std::make_shared<ExecutionTree>();
  275. // Create a schema using the C api's
  276. std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  277. // 2 columns. First column is an "image" 640,480,3
  278. TensorShape c1Shape({640, 480, 3});
  279. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  280. rank, // not used
  281. &c1Shape);
  282. // Column 2 will just be a scalar label number
  283. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  284. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  285. testSchema->AddColumn(c1);
  286. testSchema->AddColumn(c2);
  287. // RandomDataOp
  288. std::shared_ptr<RandomDataOp> myRandomDataOp;
  289. rc = RandomDataOp::Builder()
  290. .SetRowsPerBuffer(2)
  291. .SetNumWorkers(4)
  292. .SetDataSchema(std::move(testSchema))
  293. .SetTotalRows(10)
  294. .Build(&myRandomDataOp);
  295. EXPECT_TRUE(rc.IsOk());
  296. rc = myTree->AssociateNode(myRandomDataOp);
  297. EXPECT_TRUE(rc.IsOk());
  298. // CacheOp
  299. int64_t num_samples = 0;
  300. int64_t start_index = 0;
  301. auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  302. std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true);
  303. std::shared_ptr<CacheOp> myCacheOp;
  304. rc = CacheOp::Builder()
  305. .SetNumWorkers(4)
  306. .SetClient(myClient)
  307. .SetRowsPerBuffer(3)
  308. .SetSampler(std::move(seq_sampler))
  309. .Build(&myCacheOp);
  310. EXPECT_TRUE(rc.IsOk());
  311. rc = myTree->AssociateNode(myCacheOp);
  312. EXPECT_TRUE(rc.IsOk());
  313. // RepeatOp
  314. uint32_t numRepeats = 4;
  315. std::shared_ptr<RepeatOp> myRepeatOp;
  316. rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  317. EXPECT_TRUE(rc.IsOk());
  318. rc = myTree->AssociateNode(myRepeatOp);
  319. EXPECT_TRUE(rc.IsOk());
  320. // Assign tree relations and root
  321. rc = myRepeatOp->AddChild(myCacheOp);
  322. EXPECT_TRUE(rc.IsOk());
  323. rc = myCacheOp->AddChild(myRandomDataOp);
  324. EXPECT_TRUE(rc.IsOk());
  325. rc = myTree->AssignRoot(myRepeatOp);
  326. EXPECT_TRUE(rc.IsOk());
  327. MS_LOG(INFO) << "Launching tree and begin iteration";
  328. rc = myTree->Prepare();
  329. EXPECT_TRUE(rc.IsOk());
  330. std::cout << *myClient << std::endl;
  331. rc = myTree->Launch();
  332. EXPECT_TRUE(rc.IsOk());
  333. // Start the loop of reading tensors from our pipeline
  334. DatasetIterator dI(myTree);
  335. TensorRow tensorList;
  336. rc = dI.FetchNextTensorRow(&tensorList);
  337. EXPECT_TRUE(rc.IsOk());
  338. int rowCount = 0;
  339. while (!tensorList.empty()) {
  340. // Don't display these rows, just count them
  341. MS_LOG(INFO) << "Row fetched #: " << rowCount;
  342. rc = dI.FetchNextTensorRow(&tensorList);
  343. EXPECT_TRUE(rc.IsOk());
  344. rowCount++;
  345. }
  346. ASSERT_EQ(rowCount, 40);
  347. rc = myClient->DestroyCache();
  348. EXPECT_TRUE(rc.IsOk());
  349. }
  350. TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
  351. Status rc;
  352. int64_t num_samples = 0;
  353. int64_t start_index = 0;
  354. auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  355. std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true);
  356. std::shared_ptr<CacheMergeOp> myMergeOp;
  357. rc = CacheMergeOp::Builder().SetNumWorkers(3).SetOpConnectorSize(3).SetNumCleaner(2).SetClient(myClient).Build(
  358. &myMergeOp);
  359. EXPECT_TRUE(rc.IsOk());
  360. std::shared_ptr<CacheLookupOp> myLookupOp;
  361. rc = CacheLookupOp::Builder()
  362. .SetNumWorkers(3)
  363. .SetOpConnectorSize(3)
  364. .SetClient(myClient)
  365. .SetSampler(seq_sampler)
  366. .Build(&myLookupOp);
  367. EXPECT_TRUE(rc.IsOk());
  368. std::shared_ptr<ImageFolderOp> so;
  369. ImageFolderOp::Builder builder;
  370. builder.SetSampler(myLookupOp)
  371. .SetOpConnectorSize(3)
  372. .SetNumWorkers(3)
  373. .SetRowsPerBuffer(2)
  374. .SetExtensions({".jpg", ".JPEG"})
  375. .SetRecursive(true)
  376. .SetImageFolderDir(datasets_root_path_ + "/testPK/data");
  377. rc = builder.Build(&so);
  378. EXPECT_TRUE(rc.IsOk());
  379. // RepeatOp
  380. uint32_t numRepeats = 4;
  381. std::shared_ptr<RepeatOp> myRepeatOp;
  382. rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  383. EXPECT_TRUE(rc.IsOk());
  384. auto myTree = std::make_shared<ExecutionTree>();
  385. rc = myTree->AssociateNode(so);
  386. EXPECT_TRUE(rc.IsOk());
  387. rc = myTree->AssociateNode(myLookupOp);
  388. EXPECT_TRUE(rc.IsOk());
  389. rc = myTree->AssociateNode(myMergeOp);
  390. EXPECT_TRUE(rc.IsOk());
  391. rc = myTree->AssociateNode(myRepeatOp);
  392. EXPECT_TRUE(rc.IsOk());
  393. rc = myTree->AssignRoot(myRepeatOp);
  394. EXPECT_TRUE(rc.IsOk());
  395. rc = myRepeatOp->AddChild(myMergeOp);
  396. EXPECT_TRUE(rc.IsOk());
  397. rc = myMergeOp->AddChild(myLookupOp);
  398. EXPECT_TRUE(rc.IsOk());
  399. rc = myMergeOp->AddChild(so);
  400. EXPECT_TRUE(rc.IsOk());
  401. rc = myTree->Prepare();
  402. EXPECT_TRUE(rc.IsOk());
  403. rc = myTree->Launch();
  404. EXPECT_TRUE(rc.IsOk());
  405. // Start the loop of reading tensors from our pipeline
  406. DatasetIterator dI(myTree);
  407. TensorRow tensorList;
  408. rc = dI.FetchNextTensorRow(&tensorList);
  409. EXPECT_TRUE(rc.IsOk());
  410. int rowCount = 0;
  411. while (!tensorList.empty()) {
  412. rc = dI.FetchNextTensorRow(&tensorList);
  413. EXPECT_TRUE(rc.IsOk());
  414. if (rc.IsError()) {
  415. std::cout << rc << std::endl;
  416. break;
  417. }
  418. rowCount++;
  419. }
  420. ASSERT_EQ(rowCount, 176);
  421. std::cout << "Row count : " << rowCount << std::endl;
  422. rc = myClient->DestroyCache();
  423. EXPECT_TRUE(rc.IsOk());
  424. }
//// Simple test with a repeated cache op over random data producer.
//// The difference in this one is that you do not add the sampler to the cache op directly.
//// Instead, the sampler is added as part of the leaf op construction. Then, the prepare
//// phase will pull this up from the leaf and into the cache.
//// It removes the sampler from the leaf op, which doesn't make sense there anyway for
//// the RandomDataOp which doesn't support sampling without a cache.
////
////    RepeatOp
////        |
////     CacheOp
////        |
////  RandomDataOp
////
  438. TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) {
  439. Status rc;
  440. int32_t rank = 0; // not used
  441. MS_LOG(INFO) << "UT test TestCacheInheritSampler";
  442. int64_t num_samples = 0;
  443. int64_t start_index = 0;
  444. auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  445. // Start with an empty execution tree
  446. auto myTree = std::make_shared<ExecutionTree>();
  447. // Create a schema using the C api's
  448. std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  449. // 2 columns. First column is an "image" 640,480,3
  450. TensorShape c1Shape({640, 480, 3});
  451. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  452. rank, // not used
  453. &c1Shape);
  454. // Column 2 will just be a scalar label number
  455. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  456. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  457. testSchema->AddColumn(c1);
  458. testSchema->AddColumn(c2);
  459. // RandomDataOp
  460. std::shared_ptr<RandomDataOp> myRandomDataOp;
  461. rc = RandomDataOp::Builder()
  462. .SetRowsPerBuffer(2)
  463. .SetNumWorkers(4)
  464. .SetDataSchema(std::move(testSchema))
  465. .SetTotalRows(10)
  466. .SetSampler(std::move(seq_sampler))
  467. .Build(&myRandomDataOp);
  468. EXPECT_TRUE(rc.IsOk());
  469. rc = myTree->AssociateNode(myRandomDataOp);
  470. EXPECT_TRUE(rc.IsOk());
  471. // CacheOp
  472. std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true);
  473. std::shared_ptr<CacheOp> myCacheOp;
  474. rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
  475. EXPECT_TRUE(rc.IsOk());
  476. rc = myTree->AssociateNode(myCacheOp);
  477. EXPECT_TRUE(rc.IsOk());
  478. // RepeatOp
  479. uint32_t numRepeats = 4;
  480. std::shared_ptr<RepeatOp> myRepeatOp;
  481. rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  482. EXPECT_TRUE(rc.IsOk());
  483. rc = myTree->AssociateNode(myRepeatOp);
  484. EXPECT_TRUE(rc.IsOk());
  485. // Assign tree relations and root
  486. rc = myRepeatOp->AddChild(myCacheOp);
  487. EXPECT_TRUE(rc.IsOk());
  488. rc = myCacheOp->AddChild(myRandomDataOp);
  489. EXPECT_TRUE(rc.IsOk());
  490. rc = myTree->AssignRoot(myRepeatOp);
  491. EXPECT_TRUE(rc.IsOk());
  492. MS_LOG(INFO) << "Launching tree and begin iteration";
  493. rc = myTree->Prepare();
  494. EXPECT_TRUE(rc.IsOk());
  495. std::cout << *myClient << std::endl;
  496. rc = myTree->Launch();
  497. EXPECT_TRUE(rc.IsOk());
  498. // Start the loop of reading tensors from our pipeline
  499. DatasetIterator dI(myTree);
  500. TensorRow tensorList;
  501. rc = dI.FetchNextTensorRow(&tensorList);
  502. EXPECT_TRUE(rc.IsOk());
  503. int rowCount = 0;
  504. while (!tensorList.empty()) {
  505. // Don't display these rows, just count them
  506. MS_LOG(INFO) << "Row fetched #: " << rowCount;
  507. rc = dI.FetchNextTensorRow(&tensorList);
  508. EXPECT_TRUE(rc.IsOk());
  509. rowCount++;
  510. }
  511. ASSERT_EQ(rowCount, 40);
  512. rc = myClient->DestroyCache();
  513. EXPECT_TRUE(rc.IsOk());
  514. }