You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cache_op_test.cc 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <string>
  17. #include "minddata/dataset/core/client.h"
  18. #include "minddata/dataset/engine/cache/cache_client.h"
  19. #include "minddata/dataset/engine/execution_tree.h"
  20. #include "minddata/dataset/engine/datasetops/cache_op.h"
  21. #include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
  22. #include "minddata/dataset/engine/datasetops/cache_merge_op.h"
  23. #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
  24. #include "common/common.h"
  25. #include "gtest/gtest.h"
  26. #include "utils/log_adapter.h"
  27. #include "minddata/dataset/util/storage_container.h" // lint !e322
  28. #include "minddata/dataset/engine/datasetops/source/random_data_op.h"
  29. #include "minddata/dataset/engine/data_schema.h"
  30. using namespace mindspore::dataset;
  31. using mindspore::LogStream;
  32. using mindspore::dataset::CacheClient;
  33. using mindspore::dataset::TaskGroup;
  34. using mindspore::ExceptionType::NoExceptionType;
  35. using mindspore::MsLogLevel::INFO;
  36. // Helper function to get the session id from SESSION_ID env variable
  37. Status GetSessionFromEnv(session_id_type *session_id) {
  38. RETURN_UNEXPECTED_IF_NULL(session_id);
  39. if (const char *session_env = std::getenv("SESSION_ID")) {
  40. std::string session_id_str(session_env);
  41. try {
  42. *session_id = std::stoul(session_id_str);
  43. } catch (const std::exception &e) {
  44. std::string err_msg = "Invalid numeric value for session id in env var: " + session_id_str;
  45. return Status(StatusCode::kSyntaxError, err_msg);
  46. }
  47. } else {
  48. RETURN_STATUS_UNEXPECTED("Test case requires a session id to be provided via SESSION_ID environment variable.");
  49. }
  50. return Status::OK();
  51. }
  52. class MindDataTestCacheOp : public UT::DatasetOpTesting {
  53. public:
  54. void SetUp() override {
  55. DatasetOpTesting::SetUp();
  56. GlobalInit();
  57. }
  58. };
  59. TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) {
  60. Status rc;
  61. CacheClient::Builder builder;
  62. session_id_type env_session;
  63. rc = GetSessionFromEnv(&env_session);
  64. ASSERT_TRUE(rc.IsOk());
  65. // use arbitrary session of 1, size of 0, spilling// is true
  66. builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
  67. std::shared_ptr<CacheClient> myClient;
  68. rc = builder.Build(&myClient);
  69. ASSERT_TRUE(rc.IsOk());
  70. // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
  71. rc = myClient->CreateCache(1, true);
  72. ASSERT_TRUE(rc.IsOk());
  73. std::cout << *myClient << std::endl;
  74. // Create a schema using the C api's
  75. int32_t rank = 0; // not used
  76. std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  77. // 2 columns. First column is an "image" 640,480,3
  78. TensorShape c1Shape({640, 480, 3});
  79. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  80. rank, // not used
  81. &c1Shape);
  82. // Column 2 will just be a scalar label number
  83. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  84. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  85. testSchema->AddColumn(c1);
  86. testSchema->AddColumn(c2);
  87. std::unordered_map<std::string, int32_t> map;
  88. rc = testSchema->GetColumnNameMap(&map);
  89. ASSERT_TRUE(rc.IsOk());
  90. // Test the CacheSchema api
  91. rc = myClient->CacheSchema(map);
  92. ASSERT_TRUE(rc.IsOk());
  93. // Create a tensor, take a snapshot and restore it back, and compare.
  94. std::shared_ptr<Tensor> t;
  95. Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t);
  96. t->SetItemAt<uint64_t>({0, 0}, 1);
  97. t->SetItemAt<uint64_t>({0, 1}, 2);
  98. t->SetItemAt<uint64_t>({0, 2}, 3);
  99. t->SetItemAt<uint64_t>({1, 0}, 4);
  100. t->SetItemAt<uint64_t>({1, 1}, 5);
  101. t->SetItemAt<uint64_t>({1, 2}, 6);
  102. std::cout << *t << std::endl;
  103. TensorTable tbl;
  104. TensorRow row;
  105. row.push_back(t);
  106. int64_t row_id;
  107. rc = myClient->WriteRow(row, &row_id);
  108. ASSERT_TRUE(rc.IsOk());
  109. // Switch off build phase.
  110. rc = myClient->BuildPhaseDone();
  111. ASSERT_TRUE(rc.IsOk());
  112. // Now restore from cache.
  113. row.clear();
  114. rc = myClient->GetRows({row_id}, &tbl);
  115. row = tbl.front();
  116. ASSERT_TRUE(rc.IsOk());
  117. auto r = row.front();
  118. std::cout << *r << std::endl;
  119. // Compare
  120. bool cmp = (*t == *r);
  121. ASSERT_TRUE(cmp);
  122. // Get back the schema and verify
  123. std::unordered_map<std::string, int32_t> map_out;
  124. rc = myClient->FetchSchema(&map_out);
  125. ASSERT_TRUE(rc.IsOk());
  126. cmp = (map_out == map);
  127. ASSERT_TRUE(cmp);
  128. rc = myClient->DestroyCache();
  129. ASSERT_TRUE(rc.IsOk());
  130. }
  131. TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) {
  132. // Clear the rc of the master thread if any
  133. (void)TaskManager::GetMasterThreadRc();
  134. TaskGroup vg;
  135. Status rc;
  136. session_id_type env_session;
  137. rc = GetSessionFromEnv(&env_session);
  138. ASSERT_TRUE(rc.IsOk());
  139. // use arbitrary session of 1, size 1, spilling is true
  140. CacheClient::Builder builder;
  141. // use arbitrary session of 1, size of 0, spilling// is true
  142. builder.SetSessionId(env_session).SetCacheMemSz(1).SetSpill(true);
  143. std::shared_ptr<CacheClient> myClient;
  144. rc = builder.Build(&myClient);
  145. ASSERT_TRUE(rc.IsOk());
  146. // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
  147. rc = myClient->CreateCache(1, true);
  148. ASSERT_TRUE(rc.IsOk());
  149. std::cout << *myClient << std::endl;
  150. std::shared_ptr<Tensor> t;
  151. Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t);
  152. t->SetItemAt<uint64_t>({0, 0}, 1);
  153. t->SetItemAt<uint64_t>({0, 1}, 2);
  154. t->SetItemAt<uint64_t>({0, 2}, 3);
  155. t->SetItemAt<uint64_t>({1, 0}, 4);
  156. t->SetItemAt<uint64_t>({1, 1}, 5);
  157. t->SetItemAt<uint64_t>({1, 2}, 6);
  158. TensorTable tbl;
  159. TensorRow row;
  160. row.push_back(t);
  161. // Cache tensor row t 5000 times using 10 threads.
  162. for (auto k = 0; k < 10; ++k) {
  163. Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status {
  164. TaskManager::FindMe()->Post();
  165. for (auto i = 0; i < 500; i++) {
  166. RETURN_IF_NOT_OK(myClient->WriteRow(row));
  167. }
  168. return Status::OK();
  169. });
  170. ASSERT_TRUE(vg_rc.IsOk());
  171. }
  172. ASSERT_TRUE(vg.join_all().IsOk());
  173. ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk());
  174. rc = myClient->BuildPhaseDone();
  175. ASSERT_TRUE(rc.IsOk());
  176. // Get statistics from the server.
  177. CacheServiceStat stat{};
  178. rc = myClient->GetStat(&stat);
  179. ASSERT_TRUE(rc.IsOk());
  180. std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached
  181. << "\n";
  182. // Expect there are 5000 rows there.
  183. EXPECT_EQ(5000, stat.max_row_id - stat.min_row_id + 1);
  184. // Get them all back using row id and compare with tensor t.
  185. for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) {
  186. tbl.clear();
  187. row.clear();
  188. rc = myClient->GetRows({i}, &tbl);
  189. ASSERT_TRUE(rc.IsOk());
  190. row = tbl.front();
  191. auto r = row.front();
  192. bool cmp = (*t == *r);
  193. ASSERT_TRUE(cmp);
  194. }
  195. rc = myClient->DestroyCache();
  196. ASSERT_TRUE(rc.IsOk());
  197. }
  198. // Simple test with a repeated cache op over random data producer
  199. //
  200. // RepeatOp
  201. // |
  202. // CacheOp
  203. // |
  204. // RandomDataOp
  205. //
  206. TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
  207. // Clear the rc of the master thread if any
  208. (void)TaskManager::GetMasterThreadRc();
  209. Status rc;
  210. int32_t rank = 0; // not used
  211. session_id_type env_session;
  212. rc = GetSessionFromEnv(&env_session);
  213. ASSERT_TRUE(rc.IsOk());
  214. MS_LOG(INFO) << "UT test TestRandomDataCache1";
  215. // Start with an empty execution tree
  216. auto myTree = std::make_shared<ExecutionTree>();
  217. // Create a schema using the C api's
  218. std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  219. // 2 columns. First column is an "image" 640,480,3
  220. TensorShape c1Shape({640, 480, 3});
  221. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  222. rank, // not used
  223. &c1Shape);
  224. // Column 2 will just be a scalar label number
  225. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  226. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  227. testSchema->AddColumn(c1);
  228. testSchema->AddColumn(c2);
  229. // RandomDataOp
  230. std::shared_ptr<RandomDataOp> myRandomDataOp;
  231. rc = RandomDataOp::Builder()
  232. .SetRowsPerBuffer(4)
  233. .SetNumWorkers(4)
  234. .SetDataSchema(std::move(testSchema))
  235. .SetTotalRows(50) // 50 samples for now
  236. .Build(&myRandomDataOp);
  237. ASSERT_TRUE(rc.IsOk());
  238. rc = myTree->AssociateNode(myRandomDataOp);
  239. ASSERT_TRUE(rc.IsOk());
  240. // CacheOp
  241. // size of 0, spilling is true
  242. CacheClient::Builder builder;
  243. builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
  244. std::shared_ptr<CacheClient> myClient;
  245. rc = builder.Build(&myClient);
  246. ASSERT_TRUE(rc.IsOk());
  247. std::shared_ptr<CacheOp> myCacheOp;
  248. int64_t num_samples = 0;
  249. int64_t start_index = 0;
  250. auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  251. rc = CacheOp::Builder()
  252. .SetNumWorkers(5)
  253. .SetClient(myClient)
  254. .SetRowsPerBuffer(4)
  255. .SetSampler(std::move(seq_sampler))
  256. .Build(&myCacheOp);
  257. ASSERT_TRUE(rc.IsOk());
  258. rc = myTree->AssociateNode(myCacheOp);
  259. ASSERT_TRUE(rc.IsOk());
  260. // RepeatOp
  261. uint32_t numRepeats = 4;
  262. std::shared_ptr<RepeatOp> myRepeatOp;
  263. rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  264. ASSERT_TRUE(rc.IsOk());
  265. rc = myTree->AssociateNode(myRepeatOp);
  266. ASSERT_TRUE(rc.IsOk());
  267. // Assign tree relations and root
  268. rc = myRepeatOp->AddChild(myCacheOp);
  269. ASSERT_TRUE(rc.IsOk());
  270. rc = myCacheOp->AddChild(myRandomDataOp);
  271. ASSERT_TRUE(rc.IsOk());
  272. rc = myTree->AssignRoot(myRepeatOp);
  273. ASSERT_TRUE(rc.IsOk());
  274. MS_LOG(INFO) << "Launching tree and begin iteration";
  275. rc = myTree->Prepare(1);
  276. ASSERT_TRUE(rc.IsOk());
  277. // quick check to see what tree looks like
  278. std::ostringstream ss;
  279. ss << *myTree; // some funny const error if I try to write directly to ms log stream
  280. MS_LOG(INFO) << "Here's the tree:\n" << ss.str();
  281. std::cout << *myClient << std::endl;
  282. rc = myTree->Launch();
  283. ASSERT_TRUE(rc.IsOk());
  284. // Start the loop of reading tensors from our pipeline
  285. DatasetIterator dI(myTree);
  286. TensorRow tensorList;
  287. rc = dI.FetchNextTensorRow(&tensorList);
  288. ASSERT_TRUE(rc.IsOk());
  289. int rowCount = 0;
  290. while (!tensorList.empty()) {
  291. // Don't display these rows, just count them
  292. MS_LOG(INFO) << "Row fetched #: " << rowCount;
  293. rc = dI.FetchNextTensorRow(&tensorList);
  294. ASSERT_TRUE(rc.IsOk());
  295. rowCount++;
  296. }
  297. ASSERT_EQ(rowCount, 200);
  298. rc = myClient->DestroyCache();
  299. ASSERT_TRUE(rc.IsOk());
  300. }
  301. //// Simple test with a repeated cache op over random data producer.
  302. //// This one will exceed memory and require a spill.
  303. ////
  304. //// RepeatOp
  305. //// |
  306. //// CacheOp
  307. //// |
  308. //// RandomDataOp
  309. ////
  310. TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
  311. // Clear the rc of the master thread if any
  312. (void)TaskManager::GetMasterThreadRc();
  313. Status rc;
  314. int32_t rank = 0; // not used
  315. MS_LOG(INFO) << "UT test TestRandomDataCacheSpill";
  316. session_id_type env_session;
  317. rc = GetSessionFromEnv(&env_session);
  318. ASSERT_TRUE(rc.IsOk());
  319. // Start with an empty execution tree
  320. auto myTree = std::make_shared<ExecutionTree>();
  321. // Create a schema using the C api's
  322. std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  323. // 2 columns. First column is an "image" 640,480,3
  324. TensorShape c1Shape({640, 480, 3});
  325. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  326. rank, // not used
  327. &c1Shape);
  328. // Column 2 will just be a scalar label number
  329. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  330. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  331. testSchema->AddColumn(c1);
  332. testSchema->AddColumn(c2);
  333. // RandomDataOp
  334. std::shared_ptr<RandomDataOp> myRandomDataOp;
  335. rc = RandomDataOp::Builder()
  336. .SetRowsPerBuffer(2)
  337. .SetNumWorkers(4)
  338. .SetDataSchema(std::move(testSchema))
  339. .SetTotalRows(10)
  340. .Build(&myRandomDataOp);
  341. ASSERT_TRUE(rc.IsOk());
  342. rc = myTree->AssociateNode(myRandomDataOp);
  343. ASSERT_TRUE(rc.IsOk());
  344. // CacheOp
  345. int64_t num_samples = 0;
  346. int64_t start_index = 0;
  347. auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  348. CacheClient::Builder builder;
  349. builder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true);
  350. std::shared_ptr<CacheClient> myClient;
  351. rc = builder.Build(&myClient);
  352. ASSERT_TRUE(rc.IsOk());
  353. std::shared_ptr<CacheOp> myCacheOp;
  354. rc = CacheOp::Builder()
  355. .SetNumWorkers(4)
  356. .SetClient(myClient)
  357. .SetRowsPerBuffer(3)
  358. .SetSampler(std::move(seq_sampler))
  359. .Build(&myCacheOp);
  360. ASSERT_TRUE(rc.IsOk());
  361. rc = myTree->AssociateNode(myCacheOp);
  362. ASSERT_TRUE(rc.IsOk());
  363. // RepeatOp
  364. uint32_t numRepeats = 4;
  365. std::shared_ptr<RepeatOp> myRepeatOp;
  366. rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  367. ASSERT_TRUE(rc.IsOk());
  368. rc = myTree->AssociateNode(myRepeatOp);
  369. ASSERT_TRUE(rc.IsOk());
  370. // Assign tree relations and root
  371. rc = myRepeatOp->AddChild(myCacheOp);
  372. ASSERT_TRUE(rc.IsOk());
  373. rc = myCacheOp->AddChild(myRandomDataOp);
  374. ASSERT_TRUE(rc.IsOk());
  375. rc = myTree->AssignRoot(myRepeatOp);
  376. ASSERT_TRUE(rc.IsOk());
  377. MS_LOG(INFO) << "Launching tree and begin iteration";
  378. rc = myTree->Prepare(1);
  379. ASSERT_TRUE(rc.IsOk());
  380. std::cout << *myClient << std::endl;
  381. rc = myTree->Launch();
  382. ASSERT_TRUE(rc.IsOk());
  383. // Start the loop of reading tensors from our pipeline
  384. DatasetIterator dI(myTree);
  385. TensorRow tensorList;
  386. rc = dI.FetchNextTensorRow(&tensorList);
  387. ASSERT_TRUE(rc.IsOk());
  388. int rowCount = 0;
  389. while (!tensorList.empty()) {
  390. // Don't display these rows, just count them
  391. MS_LOG(INFO) << "Row fetched #: " << rowCount;
  392. rc = dI.FetchNextTensorRow(&tensorList);
  393. ASSERT_TRUE(rc.IsOk());
  394. rowCount++;
  395. }
  396. ASSERT_EQ(rowCount, 40);
  397. rc = myClient->DestroyCache();
  398. ASSERT_TRUE(rc.IsOk());
  399. }
  400. TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
  401. // Clear the rc of the master thread if any
  402. (void)TaskManager::GetMasterThreadRc();
  403. Status rc;
  404. int64_t num_samples = 0;
  405. int64_t start_index = 0;
  406. session_id_type env_session;
  407. rc = GetSessionFromEnv(&env_session);
  408. ASSERT_TRUE(rc.IsOk());
  409. auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  410. CacheClient::Builder ccbuilder;
  411. ccbuilder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
  412. std::shared_ptr<CacheClient> myClient;
  413. rc = ccbuilder.Build(&myClient);
  414. ASSERT_TRUE(rc.IsOk());
  415. // In a mappable dataset, it uses a complex interactions of cache lookup op and cache merge op.
  416. // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by
  417. // adding a CacheOp. Then, the tree prepare code will drive a transform that will remove the CacheOp and
  418. // replace it with the required tree structures for cache lookup op and cache merge op.
  419. std::shared_ptr<CacheOp> myCacheOp;
  420. rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
  421. std::shared_ptr<ImageFolderOp> so;
  422. ImageFolderOp::Builder builder;
  423. builder.SetSampler(std::move(seq_sampler))
  424. .SetOpConnectorSize(3)
  425. .SetNumWorkers(3)
  426. .SetRowsPerBuffer(2)
  427. .SetExtensions({".jpg", ".JPEG"})
  428. .SetRecursive(true)
  429. .SetImageFolderDir(datasets_root_path_ + "/testPK/data");
  430. rc = builder.Build(&so);
  431. ASSERT_TRUE(rc.IsOk());
  432. // RepeatOp
  433. uint32_t numRepeats = 4;
  434. std::shared_ptr<RepeatOp> myRepeatOp;
  435. rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  436. ASSERT_TRUE(rc.IsOk());
  437. auto myTree = std::make_shared<ExecutionTree>();
  438. rc = myTree->AssociateNode(so);
  439. ASSERT_TRUE(rc.IsOk());
  440. rc = myTree->AssociateNode(myCacheOp);
  441. ASSERT_TRUE(rc.IsOk());
  442. rc = myTree->AssociateNode(myRepeatOp);
  443. ASSERT_TRUE(rc.IsOk());
  444. rc = myTree->AssignRoot(myRepeatOp);
  445. ASSERT_TRUE(rc.IsOk());
  446. rc = myRepeatOp->AddChild(myCacheOp);
  447. ASSERT_TRUE(rc.IsOk());
  448. rc = myCacheOp->AddChild(so);
  449. ASSERT_TRUE(rc.IsOk());
  450. rc = myTree->Prepare(1);
  451. ASSERT_TRUE(rc.IsOk());
  452. rc = myTree->Launch();
  453. ASSERT_TRUE(rc.IsOk());
  454. // Start the loop of reading tensors from our pipeline
  455. DatasetIterator dI(myTree);
  456. TensorRow tensorList;
  457. rc = dI.FetchNextTensorRow(&tensorList);
  458. ASSERT_TRUE(rc.IsOk());
  459. int rowCount = 0;
  460. while (!tensorList.empty()) {
  461. rc = dI.FetchNextTensorRow(&tensorList);
  462. ASSERT_TRUE(rc.IsOk());
  463. if (rc.IsError()) {
  464. std::cout << rc << std::endl;
  465. break;
  466. }
  467. rowCount++;
  468. }
  469. ASSERT_EQ(rowCount, 176);
  470. std::cout << "Row count : " << rowCount << std::endl;
  471. rc = myClient->DestroyCache();
  472. ASSERT_TRUE(rc.IsOk());
  473. }
  474. //// Simple test with a repeated cache op over random data producer.
  475. //// The difference in this one is that you do not add the sampler to the cache op directly.
  476. //// Instead, the sampler is added as part of the leaf op construction. Then, the prepare
  477. //// phase will pull this up from the leaf and into the cache.
  478. //// It removes the sampler from the leaf op, which doesn't make sense there anyway for
  479. //// the RandomDataOp which doesn't support sampling without a cache.
  480. ////
  481. //// RepeatOp
  482. //// |
  483. //// CacheOp
  484. //// |
  485. //// RandomDataOp
  486. ////
  487. TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
  488. // Clear the rc of the master thread if any
  489. (void)TaskManager::GetMasterThreadRc();
  490. Status rc;
  491. int32_t rank = 0; // not used
  492. MS_LOG(INFO) << "UT test TestCacheInheritSampler";
  493. session_id_type env_session;
  494. rc = GetSessionFromEnv(&env_session);
  495. ASSERT_TRUE(rc.IsOk());
  496. int64_t num_samples = 0;
  497. int64_t start_index = 0;
  498. auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  499. // Start with an empty execution tree
  500. auto myTree = std::make_shared<ExecutionTree>();
  501. // Create a schema using the C api's
  502. std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  503. // 2 columns. First column is an "image" 640,480,3
  504. TensorShape c1Shape({640, 480, 3});
  505. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  506. rank, // not used
  507. &c1Shape);
  508. // Column 2 will just be a scalar label number
  509. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  510. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  511. testSchema->AddColumn(c1);
  512. testSchema->AddColumn(c2);
  513. // RandomDataOp
  514. std::shared_ptr<RandomDataOp> myRandomDataOp;
  515. rc = RandomDataOp::Builder()
  516. .SetRowsPerBuffer(2)
  517. .SetNumWorkers(4)
  518. .SetDataSchema(std::move(testSchema))
  519. .SetTotalRows(10)
  520. .SetSampler(std::move(seq_sampler))
  521. .Build(&myRandomDataOp);
  522. ASSERT_TRUE(rc.IsOk());
  523. rc = myTree->AssociateNode(myRandomDataOp);
  524. ASSERT_TRUE(rc.IsOk());
  525. // CacheOp
  526. CacheClient::Builder ccbuilder;
  527. // use arbitrary session of 1, size of 0, spilling// is true
  528. ccbuilder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true);
  529. std::shared_ptr<CacheClient> myClient;
  530. rc = ccbuilder.Build(&myClient);
  531. ASSERT_TRUE(rc.IsOk());
  532. std::shared_ptr<CacheOp> myCacheOp;
  533. rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
  534. ASSERT_TRUE(rc.IsOk());
  535. rc = myTree->AssociateNode(myCacheOp);
  536. ASSERT_TRUE(rc.IsOk());
  537. // RepeatOp
  538. uint32_t numRepeats = 4;
  539. std::shared_ptr<RepeatOp> myRepeatOp;
  540. rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  541. ASSERT_TRUE(rc.IsOk());
  542. rc = myTree->AssociateNode(myRepeatOp);
  543. ASSERT_TRUE(rc.IsOk());
  544. // Assign tree relations and root
  545. rc = myRepeatOp->AddChild(myCacheOp);
  546. ASSERT_TRUE(rc.IsOk());
  547. rc = myCacheOp->AddChild(myRandomDataOp);
  548. ASSERT_TRUE(rc.IsOk());
  549. rc = myTree->AssignRoot(myRepeatOp);
  550. ASSERT_TRUE(rc.IsOk());
  551. MS_LOG(INFO) << "Launching tree and begin iteration";
  552. rc = myTree->Prepare(1);
  553. ASSERT_TRUE(rc.IsOk());
  554. std::cout << *myClient << std::endl;
  555. rc = myTree->Launch();
  556. ASSERT_TRUE(rc.IsOk());
  557. // Start the loop of reading tensors from our pipeline
  558. DatasetIterator dI(myTree);
  559. TensorRow tensorList;
  560. rc = dI.FetchNextTensorRow(&tensorList);
  561. ASSERT_TRUE(rc.IsOk());
  562. int rowCount = 0;
  563. while (!tensorList.empty()) {
  564. // Don't display these rows, just count them
  565. MS_LOG(INFO) << "Row fetched #: " << rowCount;
  566. rc = dI.FetchNextTensorRow(&tensorList);
  567. ASSERT_TRUE(rc.IsOk());
  568. rowCount++;
  569. }
  570. ASSERT_EQ(rowCount, 40);
  571. rc = myClient->DestroyCache();
  572. ASSERT_TRUE(rc.IsOk());
  573. }