
cache_op_test.cc

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <string>
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/util/storage_container.h"  // lint !e322
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/data_schema.h"

using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::dataset::CacheClient;
using mindspore::dataset::TaskGroup;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
class MindDataTestCacheOp : public UT::DatasetOpTesting {
 public:
  void SetUp() override {
    DatasetOpTesting::SetUp();
    GlobalInit();
  }
};
TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) {
  Status rc;
  CacheClient::Builder builder;
  // Use an arbitrary session id of 1, memory size of 0, spilling enabled.
  builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
  std::shared_ptr<CacheClient> myClient;
  rc = builder.Build(&myClient);
  ASSERT_TRUE(rc.IsOk());
  // cksum value of 1 for CreateCache here... normally you do not directly create a cache and the cksum arg is
  // generated.
  rc = myClient->CreateCache(1, true);
  ASSERT_TRUE(rc.IsOk());
  std::cout << *myClient << std::endl;
  // Create a schema using the C APIs
  int32_t rank = 0;  // not used
  std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  // 2 columns. First column is an "image" 640,480,3
  TensorShape c1Shape({640, 480, 3});
  ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
                   rank,  // not used
                   &c1Shape);
  // Column 2 will just be a scalar label number
  TensorShape c2Shape({});  // empty shape is a 1-value scalar Tensor
  ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  testSchema->AddColumn(c1);
  testSchema->AddColumn(c2);
  std::unordered_map<std::string, int32_t> map;
  rc = testSchema->GetColumnNameMap(&map);
  ASSERT_TRUE(rc.IsOk());
  // Test the CacheSchema api
  rc = myClient->CacheSchema(map);
  ASSERT_TRUE(rc.IsOk());
  // Create a tensor, take a snapshot and restore it back, and compare.
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t);
  t->SetItemAt<uint64_t>({0, 0}, 1);
  t->SetItemAt<uint64_t>({0, 1}, 2);
  t->SetItemAt<uint64_t>({0, 2}, 3);
  t->SetItemAt<uint64_t>({1, 0}, 4);
  t->SetItemAt<uint64_t>({1, 1}, 5);
  t->SetItemAt<uint64_t>({1, 2}, 6);
  std::cout << *t << std::endl;
  TensorTable tbl;
  TensorRow row;
  row.push_back(t);
  int64_t row_id;
  rc = myClient->WriteRow(row, &row_id);
  ASSERT_TRUE(rc.IsOk());
  // Switch off build phase.
  rc = myClient->BuildPhaseDone();
  ASSERT_TRUE(rc.IsOk());
  // Now restore from cache.
  row.clear();
  rc = myClient->GetRows({row_id}, &tbl);
  ASSERT_TRUE(rc.IsOk());
  row = tbl.front();
  auto r = row.front();
  std::cout << *r << std::endl;
  // Compare
  bool cmp = (*t == *r);
  ASSERT_TRUE(cmp);
  // Get back the schema and verify
  std::unordered_map<std::string, int32_t> map_out;
  rc = myClient->FetchSchema(&map_out);
  ASSERT_TRUE(rc.IsOk());
  cmp = (map_out == map);
  ASSERT_TRUE(cmp);
  // Test Purge and Destroy
  rc = myClient->PurgeCache();
  ASSERT_TRUE(rc.IsOk());
  rc = myClient->DestroyCache();
  ASSERT_TRUE(rc.IsOk());
}
TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) {
  // Clear the rc of the master thread if any
  (void)TaskManager::GetMasterThreadRc();
  TaskGroup vg;
  Status rc;
  CacheClient::Builder builder;
  // Use an arbitrary session id of 1, memory size of 1, spilling enabled.
  builder.SetSessionId(1).SetCacheMemSz(1).SetSpill(true);
  std::shared_ptr<CacheClient> myClient;
  rc = builder.Build(&myClient);
  ASSERT_TRUE(rc.IsOk());
  // cksum value of 1 for CreateCache here... normally you do not directly create a cache and the cksum arg is
  // generated.
  rc = myClient->CreateCache(1, true);
  ASSERT_TRUE(rc.IsOk());
  std::cout << *myClient << std::endl;
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t);
  t->SetItemAt<uint64_t>({0, 0}, 1);
  t->SetItemAt<uint64_t>({0, 1}, 2);
  t->SetItemAt<uint64_t>({0, 2}, 3);
  t->SetItemAt<uint64_t>({1, 0}, 4);
  t->SetItemAt<uint64_t>({1, 1}, 5);
  t->SetItemAt<uint64_t>({1, 2}, 6);
  TensorTable tbl;
  TensorRow row;
  row.push_back(t);
  // Cache tensor row t 5000 times using 10 threads (500 writes each).
  for (auto k = 0; k < 10; ++k) {
    Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status {
      TaskManager::FindMe()->Post();
      for (auto i = 0; i < 500; i++) {
        RETURN_IF_NOT_OK(myClient->WriteRow(row));
      }
      return Status::OK();
    });
    ASSERT_TRUE(vg_rc.IsOk());
  }
  ASSERT_TRUE(vg.join_all().IsOk());
  ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk());
  rc = myClient->BuildPhaseDone();
  ASSERT_TRUE(rc.IsOk());
  // Get statistics from the server.
  CacheServiceStat stat{};
  rc = myClient->GetStat(&stat);
  ASSERT_TRUE(rc.IsOk());
  std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached
            << "\n";
  // Expect there are 5000 rows there.
  EXPECT_EQ(5000, stat.max_row_id - stat.min_row_id + 1);
  // Get them all back using row id and compare with tensor t.
  for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) {
    tbl.clear();
    row.clear();
    rc = myClient->GetRows({i}, &tbl);
    ASSERT_TRUE(rc.IsOk());
    row = tbl.front();
    auto r = row.front();
    bool cmp = (*t == *r);
    ASSERT_TRUE(cmp);
  }
  rc = myClient->DestroyCache();
  ASSERT_TRUE(rc.IsOk());
}
// Simple test with a repeated cache op over a random data producer
//
//    RepeatOp
//       |
//    CacheOp
//       |
//   RandomDataOp
//
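// (Note added for clarity, derived from the parameters used below: the leaf produces 50 rows and the
//  RepeatOp repeats 4 times, so the iterator is expected to return 50 * 4 = 200 rows, which is what the
//  final ASSERT_EQ in this test checks.)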
TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
  Status rc;
  int32_t rank = 0;  // not used
  MS_LOG(INFO) << "UT test TestRandomDataCache1";
  // Start with an empty execution tree
  auto myTree = std::make_shared<ExecutionTree>();
  // Create a schema using the C APIs
  std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  // 2 columns. First column is an "image" 640,480,3
  TensorShape c1Shape({640, 480, 3});
  ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
                   rank,  // not used
                   &c1Shape);
  // Column 2 will just be a scalar label number
  TensorShape c2Shape({});  // empty shape is a 1-value scalar Tensor
  ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  testSchema->AddColumn(c1);
  testSchema->AddColumn(c2);
  // RandomDataOp
  std::shared_ptr<RandomDataOp> myRandomDataOp;
  rc = RandomDataOp::Builder()
         .SetRowsPerBuffer(4)
         .SetNumWorkers(4)
         .SetDataSchema(std::move(testSchema))
         .SetTotalRows(50)  // 50 samples for now
         .Build(&myRandomDataOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRandomDataOp);
  ASSERT_TRUE(rc.IsOk());
  // CacheOp
  CacheClient::Builder builder;
  // Use an arbitrary session id of 1, memory size of 0, spilling enabled.
  builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
  std::shared_ptr<CacheClient> myClient;
  rc = builder.Build(&myClient);
  ASSERT_TRUE(rc.IsOk());
  std::shared_ptr<CacheOp> myCacheOp;
  int64_t num_samples = 0;
  int64_t start_index = 0;
  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  rc = CacheOp::Builder()
         .SetNumWorkers(5)
         .SetClient(myClient)
         .SetRowsPerBuffer(4)
         .SetSampler(std::move(seq_sampler))
         .Build(&myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  // RepeatOp
  uint32_t numRepeats = 4;
  std::shared_ptr<RepeatOp> myRepeatOp;
  rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  // Assign tree relations and root
  rc = myRepeatOp->AddChild(myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myCacheOp->AddChild(myRandomDataOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssignRoot(myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  MS_LOG(INFO) << "Launching tree and begin iteration";
  rc = myTree->Prepare();
  ASSERT_TRUE(rc.IsOk());
  // Quick check to see what the tree looks like
  std::ostringstream ss;
  ss << *myTree;  // writing directly to the MS log stream triggers a const error
  MS_LOG(INFO) << "Here's the tree:\n" << ss.str();
  std::cout << *myClient << std::endl;
  rc = myTree->Launch();
  ASSERT_TRUE(rc.IsOk());
  // Start the loop of reading tensors from our pipeline
  DatasetIterator dI(myTree);
  TensorRow tensorList;
  rc = dI.FetchNextTensorRow(&tensorList);
  ASSERT_TRUE(rc.IsOk());
  int rowCount = 0;
  while (!tensorList.empty()) {
    // Don't display these rows, just count them
    MS_LOG(INFO) << "Row fetched #: " << rowCount;
    rc = dI.FetchNextTensorRow(&tensorList);
    ASSERT_TRUE(rc.IsOk());
    rowCount++;
  }
  ASSERT_EQ(rowCount, 200);
  rc = myClient->DestroyCache();
  ASSERT_TRUE(rc.IsOk());
}
// Simple test with a repeated cache op over a random data producer.
// This one will exceed memory and require a spill.
//
//    RepeatOp
//       |
//    CacheOp
//       |
//   RandomDataOp
//
TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
  Status rc;
  int32_t rank = 0;  // not used
  MS_LOG(INFO) << "UT test TestRandomDataCacheSpill";
  // Start with an empty execution tree
  auto myTree = std::make_shared<ExecutionTree>();
  // Create a schema using the C APIs
  std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  // 2 columns. First column is an "image" 640,480,3
  TensorShape c1Shape({640, 480, 3});
  ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
                   rank,  // not used
                   &c1Shape);
  // Column 2 will just be a scalar label number
  TensorShape c2Shape({});  // empty shape is a 1-value scalar Tensor
  ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  testSchema->AddColumn(c1);
  testSchema->AddColumn(c2);
  // RandomDataOp
  std::shared_ptr<RandomDataOp> myRandomDataOp;
  rc = RandomDataOp::Builder()
         .SetRowsPerBuffer(2)
         .SetNumWorkers(4)
         .SetDataSchema(std::move(testSchema))
         .SetTotalRows(10)
         .Build(&myRandomDataOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRandomDataOp);
  ASSERT_TRUE(rc.IsOk());
  // CacheOp
  int64_t num_samples = 0;
  int64_t start_index = 0;
  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  CacheClient::Builder builder;
  // Use an arbitrary session id of 1, memory size of 4, spilling enabled.
  builder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true);
  std::shared_ptr<CacheClient> myClient;
  rc = builder.Build(&myClient);
  ASSERT_TRUE(rc.IsOk());
  std::shared_ptr<CacheOp> myCacheOp;
  rc = CacheOp::Builder()
         .SetNumWorkers(4)
         .SetClient(myClient)
         .SetRowsPerBuffer(3)
         .SetSampler(std::move(seq_sampler))
         .Build(&myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  // RepeatOp
  uint32_t numRepeats = 4;
  std::shared_ptr<RepeatOp> myRepeatOp;
  rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  // Assign tree relations and root
  rc = myRepeatOp->AddChild(myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myCacheOp->AddChild(myRandomDataOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssignRoot(myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  MS_LOG(INFO) << "Launching tree and begin iteration";
  rc = myTree->Prepare();
  ASSERT_TRUE(rc.IsOk());
  std::cout << *myClient << std::endl;
  rc = myTree->Launch();
  ASSERT_TRUE(rc.IsOk());
  // Start the loop of reading tensors from our pipeline
  DatasetIterator dI(myTree);
  TensorRow tensorList;
  rc = dI.FetchNextTensorRow(&tensorList);
  ASSERT_TRUE(rc.IsOk());
  int rowCount = 0;
  while (!tensorList.empty()) {
    // Don't display these rows, just count them
    MS_LOG(INFO) << "Row fetched #: " << rowCount;
    rc = dI.FetchNextTensorRow(&tensorList);
    ASSERT_TRUE(rc.IsOk());
    rowCount++;
  }
  ASSERT_EQ(rowCount, 40);
  rc = myClient->DestroyCache();
  ASSERT_TRUE(rc.IsOk());
}
TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
  Status rc;
  int64_t num_samples = 0;
  int64_t start_index = 0;
  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  CacheClient::Builder ccbuilder;
  // Use an arbitrary session id of 1, memory size of 0, spilling enabled.
  ccbuilder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
  std::shared_ptr<CacheClient> myClient;
  rc = ccbuilder.Build(&myClient);
  ASSERT_TRUE(rc.IsOk());
  // A mappable dataset uses a more complex interaction of a cache lookup op and a cache merge op.
  // Rather than building this manually, the way to do it is to choose the position of the cache in the tree by
  // adding a CacheOp. The tree prepare code then drives a transform that removes the CacheOp and
  // replaces it with the required tree structure of cache lookup op and cache merge op.
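  // (Added note: a rough sketch of the transformed tree after Prepare(). This shape is inferred from the
  //  comment above and the cache_lookup_op / cache_merge_op includes; it is an illustration, not verified
  //  against the prepare pass itself.)
  //
  //        RepeatOp
  //           |
  //      CacheMergeOp
  //       /         \
  //  CacheLookupOp   ImageFolderOp   (lookup serves cache hits; the leaf fills cache misses)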
  std::shared_ptr<CacheOp> myCacheOp;
  rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
  std::shared_ptr<ImageFolderOp> so;
  ImageFolderOp::Builder builder;
  builder.SetSampler(std::move(seq_sampler))
    .SetOpConnectorSize(3)
    .SetNumWorkers(3)
    .SetRowsPerBuffer(2)
    .SetExtensions({".jpg", ".JPEG"})
    .SetRecursive(true)
    .SetImageFolderDir(datasets_root_path_ + "/testPK/data");
  rc = builder.Build(&so);
  ASSERT_TRUE(rc.IsOk());
  // RepeatOp
  uint32_t numRepeats = 4;
  std::shared_ptr<RepeatOp> myRepeatOp;
  rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  auto myTree = std::make_shared<ExecutionTree>();
  rc = myTree->AssociateNode(so);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssignRoot(myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myRepeatOp->AddChild(myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myCacheOp->AddChild(so);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->Prepare();
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->Launch();
  ASSERT_TRUE(rc.IsOk());
  // Start the loop of reading tensors from our pipeline
  DatasetIterator dI(myTree);
  TensorRow tensorList;
  rc = dI.FetchNextTensorRow(&tensorList);
  ASSERT_TRUE(rc.IsOk());
  int rowCount = 0;
  while (!tensorList.empty()) {
    rc = dI.FetchNextTensorRow(&tensorList);
    ASSERT_TRUE(rc.IsOk());
    if (rc.IsError()) {
      std::cout << rc << std::endl;
      break;
    }
    rowCount++;
  }
  ASSERT_EQ(rowCount, 176);
  std::cout << "Row count : " << rowCount << std::endl;
  rc = myClient->DestroyCache();
  ASSERT_TRUE(rc.IsOk());
}
// Simple test with a repeated cache op over a random data producer.
// The difference in this one is that the sampler is not added to the cache op directly.
// Instead, the sampler is added as part of the leaf op construction. The prepare
// phase then pulls the sampler up from the leaf and into the cache, removing it from
// the leaf op, where it doesn't make sense anyway for a RandomDataOp that doesn't
// support sampling without a cache.
//
//    RepeatOp
//       |
//    CacheOp
//       |
//   RandomDataOp
//
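// (Added note: apart from where the sampler is attached, this pipeline mirrors TestRandomDataCacheSpill,
//  so the iterator is again expected to return 10 * 4 = 40 rows, which is what the final ASSERT_EQ checks.)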
TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
  Status rc;
  int32_t rank = 0;  // not used
  MS_LOG(INFO) << "UT test TestCacheInheritSampler";
  int64_t num_samples = 0;
  int64_t start_index = 0;
  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  // Start with an empty execution tree
  auto myTree = std::make_shared<ExecutionTree>();
  // Create a schema using the C APIs
  std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  // 2 columns. First column is an "image" 640,480,3
  TensorShape c1Shape({640, 480, 3});
  ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
                   rank,  // not used
                   &c1Shape);
  // Column 2 will just be a scalar label number
  TensorShape c2Shape({});  // empty shape is a 1-value scalar Tensor
  ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  testSchema->AddColumn(c1);
  testSchema->AddColumn(c2);
  // RandomDataOp
  std::shared_ptr<RandomDataOp> myRandomDataOp;
  rc = RandomDataOp::Builder()
         .SetRowsPerBuffer(2)
         .SetNumWorkers(4)
         .SetDataSchema(std::move(testSchema))
         .SetTotalRows(10)
         .SetSampler(std::move(seq_sampler))
         .Build(&myRandomDataOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRandomDataOp);
  ASSERT_TRUE(rc.IsOk());
  // CacheOp
  CacheClient::Builder ccbuilder;
  // Use an arbitrary session id of 1, memory size of 4, spilling enabled.
  ccbuilder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true);
  std::shared_ptr<CacheClient> myClient;
  rc = ccbuilder.Build(&myClient);
  ASSERT_TRUE(rc.IsOk());
  std::shared_ptr<CacheOp> myCacheOp;
  rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  // RepeatOp
  uint32_t numRepeats = 4;
  std::shared_ptr<RepeatOp> myRepeatOp;
  rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  // Assign tree relations and root
  rc = myRepeatOp->AddChild(myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myCacheOp->AddChild(myRandomDataOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssignRoot(myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  MS_LOG(INFO) << "Launching tree and begin iteration";
  rc = myTree->Prepare();
  ASSERT_TRUE(rc.IsOk());
  std::cout << *myClient << std::endl;
  rc = myTree->Launch();
  ASSERT_TRUE(rc.IsOk());
  // Start the loop of reading tensors from our pipeline
  DatasetIterator dI(myTree);
  TensorRow tensorList;
  rc = dI.FetchNextTensorRow(&tensorList);
  ASSERT_TRUE(rc.IsOk());
  int rowCount = 0;
  while (!tensorList.empty()) {
    // Don't display these rows, just count them
    MS_LOG(INFO) << "Row fetched #: " << rowCount;
    rc = dI.FetchNextTensorRow(&tensorList);
    ASSERT_TRUE(rc.IsOk());
    rowCount++;
  }
  ASSERT_EQ(rowCount, 40);
  rc = myClient->DestroyCache();
  ASSERT_TRUE(rc.IsOk());
}