You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cache_op_test.cc 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <string>
  17. #include "minddata/dataset/core/client.h"
  18. #include "minddata/dataset/engine/cache/cache_client.h"
  19. #include "minddata/dataset/engine/execution_tree.h"
  20. #include "minddata/dataset/engine/datasetops/cache_op.h"
  21. #include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
  22. #include "minddata/dataset/engine/datasetops/cache_merge_op.h"
  23. #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
  24. #include "common/common.h"
  25. #include "gtest/gtest.h"
  26. #include "utils/log_adapter.h"
  27. #include "minddata/dataset/engine/datasetops/source/random_data_op.h"
  28. #include "minddata/dataset/engine/data_schema.h"
  29. using namespace mindspore::dataset;
  30. using mindspore::LogStream;
  31. using mindspore::dataset::CacheClient;
  32. using mindspore::dataset::TaskGroup;
  33. using mindspore::ExceptionType::NoExceptionType;
  34. using mindspore::MsLogLevel::INFO;
  35. // Helper function to get the session id from SESSION_ID env variable
  36. Status GetSessionFromEnv(session_id_type *session_id) {
  37. RETURN_UNEXPECTED_IF_NULL(session_id);
  38. if (const char *session_env = std::getenv("SESSION_ID")) {
  39. std::string session_id_str(session_env);
  40. try {
  41. *session_id = std::stoul(session_id_str);
  42. } catch (const std::exception &e) {
  43. std::string err_msg = "Invalid numeric value for session id in env var: " + session_id_str;
  44. return Status(StatusCode::kSyntaxError, err_msg);
  45. }
  46. } else {
  47. RETURN_STATUS_UNEXPECTED("Test case requires a session id to be provided via SESSION_ID environment variable.");
  48. }
  49. return Status::OK();
  50. }
  51. class MindDataTestCacheOp : public UT::DatasetOpTesting {
  52. public:
  53. void SetUp() override {
  54. DatasetOpTesting::SetUp();
  55. GlobalInit();
  56. }
  57. };
  58. TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) {
  59. Status rc;
  60. CacheClient::Builder builder;
  61. session_id_type env_session;
  62. rc = GetSessionFromEnv(&env_session);
  63. ASSERT_TRUE(rc.IsOk());
  64. // use arbitrary session of 1, size of 0, spilling// is true
  65. builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
  66. std::shared_ptr<CacheClient> myClient;
  67. rc = builder.Build(&myClient);
  68. ASSERT_TRUE(rc.IsOk());
  69. // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
  70. rc = myClient->CreateCache(1, true);
  71. ASSERT_TRUE(rc.IsOk());
  72. std::cout << *myClient << std::endl;
  73. // Create a schema using the C api's
  74. int32_t rank = 0; // not used
  75. std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  76. // 2 columns. First column is an "image" 640,480,3
  77. TensorShape c1Shape({640, 480, 3});
  78. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  79. rank, // not used
  80. &c1Shape);
  81. // Column 2 will just be a scalar label number
  82. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  83. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  84. testSchema->AddColumn(c1);
  85. testSchema->AddColumn(c2);
  86. std::unordered_map<std::string, int32_t> map;
  87. rc = testSchema->GetColumnNameMap(&map);
  88. ASSERT_TRUE(rc.IsOk());
  89. // Test the CacheSchema api
  90. rc = myClient->CacheSchema(map);
  91. ASSERT_TRUE(rc.IsOk());
  92. // Create a tensor, take a snapshot and restore it back, and compare.
  93. std::shared_ptr<Tensor> t;
  94. Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t);
  95. t->SetItemAt<uint64_t>({0, 0}, 1);
  96. t->SetItemAt<uint64_t>({0, 1}, 2);
  97. t->SetItemAt<uint64_t>({0, 2}, 3);
  98. t->SetItemAt<uint64_t>({1, 0}, 4);
  99. t->SetItemAt<uint64_t>({1, 1}, 5);
  100. t->SetItemAt<uint64_t>({1, 2}, 6);
  101. std::cout << *t << std::endl;
  102. TensorTable tbl;
  103. TensorRow row;
  104. row.push_back(t);
  105. int64_t row_id;
  106. rc = myClient->WriteRow(row, &row_id);
  107. ASSERT_TRUE(rc.IsOk());
  108. // Switch off build phase.
  109. rc = myClient->BuildPhaseDone();
  110. ASSERT_TRUE(rc.IsOk());
  111. // Now restore from cache.
  112. row.clear();
  113. rc = myClient->GetRows({row_id}, &tbl);
  114. row = tbl.front();
  115. ASSERT_TRUE(rc.IsOk());
  116. auto r = row.front();
  117. std::cout << *r << std::endl;
  118. // Compare
  119. bool cmp = (*t == *r);
  120. ASSERT_TRUE(cmp);
  121. // Get back the schema and verify
  122. std::unordered_map<std::string, int32_t> map_out;
  123. rc = myClient->FetchSchema(&map_out);
  124. ASSERT_TRUE(rc.IsOk());
  125. cmp = (map_out == map);
  126. ASSERT_TRUE(cmp);
  127. rc = myClient->DestroyCache();
  128. ASSERT_TRUE(rc.IsOk());
  129. }
  130. TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) {
  131. // Clear the rc of the master thread if any
  132. (void)TaskManager::GetMasterThreadRc();
  133. TaskGroup vg;
  134. Status rc;
  135. session_id_type env_session;
  136. rc = GetSessionFromEnv(&env_session);
  137. ASSERT_TRUE(rc.IsOk());
  138. // use arbitrary session of 1, size 1, spilling is true
  139. CacheClient::Builder builder;
  140. // use arbitrary session of 1, size of 0, spilling// is true
  141. builder.SetSessionId(env_session).SetCacheMemSz(1).SetSpill(true);
  142. std::shared_ptr<CacheClient> myClient;
  143. rc = builder.Build(&myClient);
  144. ASSERT_TRUE(rc.IsOk());
  145. // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
  146. rc = myClient->CreateCache(1, true);
  147. ASSERT_TRUE(rc.IsOk());
  148. std::cout << *myClient << std::endl;
  149. std::shared_ptr<Tensor> t;
  150. Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t);
  151. t->SetItemAt<uint64_t>({0, 0}, 1);
  152. t->SetItemAt<uint64_t>({0, 1}, 2);
  153. t->SetItemAt<uint64_t>({0, 2}, 3);
  154. t->SetItemAt<uint64_t>({1, 0}, 4);
  155. t->SetItemAt<uint64_t>({1, 1}, 5);
  156. t->SetItemAt<uint64_t>({1, 2}, 6);
  157. TensorTable tbl;
  158. TensorRow row;
  159. row.push_back(t);
  160. // Cache tensor row t 5000 times using 10 threads.
  161. for (auto k = 0; k < 10; ++k) {
  162. Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status {
  163. TaskManager::FindMe()->Post();
  164. for (auto i = 0; i < 500; i++) {
  165. RETURN_IF_NOT_OK(myClient->WriteRow(row));
  166. }
  167. return Status::OK();
  168. });
  169. ASSERT_TRUE(vg_rc.IsOk());
  170. }
  171. ASSERT_TRUE(vg.join_all().IsOk());
  172. ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk());
  173. rc = myClient->BuildPhaseDone();
  174. ASSERT_TRUE(rc.IsOk());
  175. // Get statistics from the server.
  176. CacheServiceStat stat{};
  177. rc = myClient->GetStat(&stat);
  178. ASSERT_TRUE(rc.IsOk());
  179. std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached
  180. << "\n";
  181. // Expect there are 5000 rows there.
  182. EXPECT_EQ(5000, stat.max_row_id - stat.min_row_id + 1);
  183. // Get them all back using row id and compare with tensor t.
  184. for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) {
  185. tbl.clear();
  186. row.clear();
  187. rc = myClient->GetRows({i}, &tbl);
  188. ASSERT_TRUE(rc.IsOk());
  189. row = tbl.front();
  190. auto r = row.front();
  191. bool cmp = (*t == *r);
  192. ASSERT_TRUE(cmp);
  193. }
  194. rc = myClient->DestroyCache();
  195. ASSERT_TRUE(rc.IsOk());
  196. }
  197. // Simple test with a repeated cache op over random data producer
  198. //
  199. // RepeatOp
  200. // |
  201. // CacheOp
  202. // |
  203. // RandomDataOp
  204. //
  205. TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
  206. // Clear the rc of the master thread if any
  207. (void)TaskManager::GetMasterThreadRc();
  208. Status rc;
  209. int32_t rank = 0; // not used
  210. session_id_type env_session;
  211. rc = GetSessionFromEnv(&env_session);
  212. ASSERT_TRUE(rc.IsOk());
  213. MS_LOG(INFO) << "UT test TestRandomDataCache1";
  214. // Start with an empty execution tree
  215. auto myTree = std::make_shared<ExecutionTree>();
  216. // Create a schema using the C api's
  217. std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  218. // 2 columns. First column is an "image" 640,480,3
  219. TensorShape c1Shape({640, 480, 3});
  220. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  221. rank, // not used
  222. &c1Shape);
  223. // Column 2 will just be a scalar label number
  224. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  225. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  226. testSchema->AddColumn(c1);
  227. testSchema->AddColumn(c2);
  228. // RandomDataOp
  229. std::shared_ptr<RandomDataOp> myRandomDataOp;
  230. rc = RandomDataOp::Builder()
  231. .SetRowsPerBuffer(4)
  232. .SetNumWorkers(4)
  233. .SetDataSchema(std::move(testSchema))
  234. .SetTotalRows(50) // 50 samples for now
  235. .Build(&myRandomDataOp);
  236. ASSERT_TRUE(rc.IsOk());
  237. rc = myTree->AssociateNode(myRandomDataOp);
  238. ASSERT_TRUE(rc.IsOk());
  239. // CacheOp
  240. // size of 0, spilling is true
  241. CacheClient::Builder builder;
  242. builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
  243. std::shared_ptr<CacheClient> myClient;
  244. rc = builder.Build(&myClient);
  245. ASSERT_TRUE(rc.IsOk());
  246. std::shared_ptr<CacheOp> myCacheOp;
  247. int64_t num_samples = 0;
  248. int64_t start_index = 0;
  249. auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
  250. rc = CacheOp::Builder()
  251. .SetNumWorkers(5)
  252. .SetClient(myClient)
  253. .SetRowsPerBuffer(4)
  254. .SetSampler(std::move(seq_sampler))
  255. .Build(&myCacheOp);
  256. ASSERT_TRUE(rc.IsOk());
  257. rc = myTree->AssociateNode(myCacheOp);
  258. ASSERT_TRUE(rc.IsOk());
  259. // RepeatOp
  260. uint32_t numRepeats = 4;
  261. std::shared_ptr<RepeatOp> myRepeatOp;
  262. rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  263. ASSERT_TRUE(rc.IsOk());
  264. rc = myTree->AssociateNode(myRepeatOp);
  265. ASSERT_TRUE(rc.IsOk());
  266. // Assign tree relations and root
  267. rc = myRepeatOp->AddChild(myCacheOp);
  268. ASSERT_TRUE(rc.IsOk());
  269. rc = myCacheOp->AddChild(myRandomDataOp);
  270. ASSERT_TRUE(rc.IsOk());
  271. rc = myTree->AssignRoot(myRepeatOp);
  272. ASSERT_TRUE(rc.IsOk());
  273. MS_LOG(INFO) << "Launching tree and begin iteration";
  274. rc = myTree->Prepare(1);
  275. ASSERT_TRUE(rc.IsOk());
  276. // quick check to see what tree looks like
  277. std::ostringstream ss;
  278. ss << *myTree; // some funny const error if I try to write directly to ms log stream
  279. MS_LOG(INFO) << "Here's the tree:\n" << ss.str();
  280. std::cout << *myClient << std::endl;
  281. rc = myTree->Launch();
  282. ASSERT_TRUE(rc.IsOk());
  283. // Start the loop of reading tensors from our pipeline
  284. DatasetIterator dI(myTree);
  285. TensorRow tensorList;
  286. rc = dI.FetchNextTensorRow(&tensorList);
  287. ASSERT_TRUE(rc.IsOk());
  288. int rowCount = 0;
  289. while (!tensorList.empty()) {
  290. // Don't display these rows, just count them
  291. MS_LOG(INFO) << "Row fetched #: " << rowCount;
  292. rc = dI.FetchNextTensorRow(&tensorList);
  293. ASSERT_TRUE(rc.IsOk());
  294. rowCount++;
  295. }
  296. ASSERT_EQ(rowCount, 200);
  297. rc = myClient->DestroyCache();
  298. ASSERT_TRUE(rc.IsOk());
  299. }
  300. //// Simple test with a repeated cache op over random data producer.
  301. //// This one will exceed memory and require a spill.
  302. ////
  303. //// RepeatOp
  304. //// |
  305. //// CacheOp
  306. //// |
  307. //// RandomDataOp
  308. ////
  309. TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
  310. // Clear the rc of the master thread if any
  311. (void)TaskManager::GetMasterThreadRc();
  312. Status rc;
  313. int32_t rank = 0; // not used
  314. MS_LOG(INFO) << "UT test TestRandomDataCacheSpill";
  315. session_id_type env_session;
  316. rc = GetSessionFromEnv(&env_session);
  317. ASSERT_TRUE(rc.IsOk());
  318. // Start with an empty execution tree
  319. auto myTree = std::make_shared<ExecutionTree>();
  320. // Create a schema using the C api's
  321. std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  322. // 2 columns. First column is an "image" 640,480,3
  323. TensorShape c1Shape({640, 480, 3});
  324. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  325. rank, // not used
  326. &c1Shape);
  327. // Column 2 will just be a scalar label number
  328. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  329. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  330. testSchema->AddColumn(c1);
  331. testSchema->AddColumn(c2);
  332. // RandomDataOp
  333. std::shared_ptr<RandomDataOp> myRandomDataOp;
  334. rc = RandomDataOp::Builder()
  335. .SetRowsPerBuffer(2)
  336. .SetNumWorkers(4)
  337. .SetDataSchema(std::move(testSchema))
  338. .SetTotalRows(10)
  339. .Build(&myRandomDataOp);
  340. ASSERT_TRUE(rc.IsOk());
  341. rc = myTree->AssociateNode(myRandomDataOp);
  342. ASSERT_TRUE(rc.IsOk());
  343. // CacheOp
  344. int64_t num_samples = 0;
  345. int64_t start_index = 0;
  346. auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
  347. CacheClient::Builder builder;
  348. builder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true);
  349. std::shared_ptr<CacheClient> myClient;
  350. rc = builder.Build(&myClient);
  351. ASSERT_TRUE(rc.IsOk());
  352. std::shared_ptr<CacheOp> myCacheOp;
  353. rc = CacheOp::Builder()
  354. .SetNumWorkers(4)
  355. .SetClient(myClient)
  356. .SetRowsPerBuffer(3)
  357. .SetSampler(std::move(seq_sampler))
  358. .Build(&myCacheOp);
  359. ASSERT_TRUE(rc.IsOk());
  360. rc = myTree->AssociateNode(myCacheOp);
  361. ASSERT_TRUE(rc.IsOk());
  362. // RepeatOp
  363. uint32_t numRepeats = 4;
  364. std::shared_ptr<RepeatOp> myRepeatOp;
  365. rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  366. ASSERT_TRUE(rc.IsOk());
  367. rc = myTree->AssociateNode(myRepeatOp);
  368. ASSERT_TRUE(rc.IsOk());
  369. // Assign tree relations and root
  370. rc = myRepeatOp->AddChild(myCacheOp);
  371. ASSERT_TRUE(rc.IsOk());
  372. rc = myCacheOp->AddChild(myRandomDataOp);
  373. ASSERT_TRUE(rc.IsOk());
  374. rc = myTree->AssignRoot(myRepeatOp);
  375. ASSERT_TRUE(rc.IsOk());
  376. MS_LOG(INFO) << "Launching tree and begin iteration";
  377. rc = myTree->Prepare(1);
  378. ASSERT_TRUE(rc.IsOk());
  379. std::cout << *myClient << std::endl;
  380. rc = myTree->Launch();
  381. ASSERT_TRUE(rc.IsOk());
  382. // Start the loop of reading tensors from our pipeline
  383. DatasetIterator dI(myTree);
  384. TensorRow tensorList;
  385. rc = dI.FetchNextTensorRow(&tensorList);
  386. ASSERT_TRUE(rc.IsOk());
  387. int rowCount = 0;
  388. while (!tensorList.empty()) {
  389. // Don't display these rows, just count them
  390. MS_LOG(INFO) << "Row fetched #: " << rowCount;
  391. rc = dI.FetchNextTensorRow(&tensorList);
  392. ASSERT_TRUE(rc.IsOk());
  393. rowCount++;
  394. }
  395. ASSERT_EQ(rowCount, 40);
  396. rc = myClient->DestroyCache();
  397. ASSERT_TRUE(rc.IsOk());
  398. }
  399. TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
  400. // Clear the rc of the master thread if any
  401. (void)TaskManager::GetMasterThreadRc();
  402. Status rc;
  403. int64_t num_samples = 0;
  404. int64_t start_index = 0;
  405. session_id_type env_session;
  406. rc = GetSessionFromEnv(&env_session);
  407. ASSERT_TRUE(rc.IsOk());
  408. auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
  409. CacheClient::Builder ccbuilder;
  410. ccbuilder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
  411. std::shared_ptr<CacheClient> myClient;
  412. rc = ccbuilder.Build(&myClient);
  413. ASSERT_TRUE(rc.IsOk());
  414. // In a mappable dataset, it uses a complex interactions of cache lookup op and cache merge op.
  415. // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by
  416. // adding a CacheOp. Then, the tree prepare code will drive a transform that will remove the CacheOp and
  417. // replace it with the required tree structures for cache lookup op and cache merge op.
  418. std::shared_ptr<CacheOp> myCacheOp;
  419. rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
  420. std::shared_ptr<ImageFolderOp> so;
  421. ImageFolderOp::Builder builder;
  422. builder.SetSampler(std::move(seq_sampler))
  423. .SetOpConnectorSize(3)
  424. .SetNumWorkers(3)
  425. .SetRowsPerBuffer(2)
  426. .SetExtensions({".jpg", ".JPEG"})
  427. .SetRecursive(true)
  428. .SetImageFolderDir(datasets_root_path_ + "/testPK/data");
  429. rc = builder.Build(&so);
  430. ASSERT_TRUE(rc.IsOk());
  431. // RepeatOp
  432. uint32_t numRepeats = 4;
  433. std::shared_ptr<RepeatOp> myRepeatOp;
  434. rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  435. ASSERT_TRUE(rc.IsOk());
  436. auto myTree = std::make_shared<ExecutionTree>();
  437. rc = myTree->AssociateNode(so);
  438. ASSERT_TRUE(rc.IsOk());
  439. rc = myTree->AssociateNode(myCacheOp);
  440. ASSERT_TRUE(rc.IsOk());
  441. rc = myTree->AssociateNode(myRepeatOp);
  442. ASSERT_TRUE(rc.IsOk());
  443. rc = myTree->AssignRoot(myRepeatOp);
  444. ASSERT_TRUE(rc.IsOk());
  445. rc = myRepeatOp->AddChild(myCacheOp);
  446. ASSERT_TRUE(rc.IsOk());
  447. rc = myCacheOp->AddChild(so);
  448. ASSERT_TRUE(rc.IsOk());
  449. rc = myTree->Prepare(1);
  450. ASSERT_TRUE(rc.IsOk());
  451. rc = myTree->Launch();
  452. ASSERT_TRUE(rc.IsOk());
  453. // Start the loop of reading tensors from our pipeline
  454. DatasetIterator dI(myTree);
  455. TensorRow tensorList;
  456. rc = dI.FetchNextTensorRow(&tensorList);
  457. ASSERT_TRUE(rc.IsOk());
  458. int rowCount = 0;
  459. while (!tensorList.empty()) {
  460. rc = dI.FetchNextTensorRow(&tensorList);
  461. ASSERT_TRUE(rc.IsOk());
  462. if (rc.IsError()) {
  463. std::cout << rc << std::endl;
  464. break;
  465. }
  466. rowCount++;
  467. }
  468. ASSERT_EQ(rowCount, 176);
  469. std::cout << "Row count : " << rowCount << std::endl;
  470. rc = myClient->DestroyCache();
  471. ASSERT_TRUE(rc.IsOk());
  472. }
  473. //// Simple test with a repeated cache op over random data producer.
  474. //// The difference in this one is that you do not add the sampler to the cache op directly.
  475. //// Instead, the sampler is added as part of the leaf op construction. Then, the prepare
  476. //// phase will pull this up from the leaf and into the cache.
  477. //// It removes the sampler from the leaf op, which doesn't make sense there anyway for
  478. //// the RandomDataOp which doesn't support sampling without a cache.
  479. ////
  480. //// RepeatOp
  481. //// |
  482. //// CacheOp
  483. //// |
  484. //// RandomDataOp
  485. ////
  486. TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
  487. // Clear the rc of the master thread if any
  488. (void)TaskManager::GetMasterThreadRc();
  489. Status rc;
  490. int32_t rank = 0; // not used
  491. MS_LOG(INFO) << "UT test TestCacheInheritSampler";
  492. session_id_type env_session;
  493. rc = GetSessionFromEnv(&env_session);
  494. ASSERT_TRUE(rc.IsOk());
  495. int64_t num_samples = 0;
  496. int64_t start_index = 0;
  497. auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
  498. // Start with an empty execution tree
  499. auto myTree = std::make_shared<ExecutionTree>();
  500. // Create a schema using the C api's
  501. std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
  502. // 2 columns. First column is an "image" 640,480,3
  503. TensorShape c1Shape({640, 480, 3});
  504. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  505. rank, // not used
  506. &c1Shape);
  507. // Column 2 will just be a scalar label number
  508. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  509. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  510. testSchema->AddColumn(c1);
  511. testSchema->AddColumn(c2);
  512. // RandomDataOp
  513. std::shared_ptr<RandomDataOp> myRandomDataOp;
  514. rc = RandomDataOp::Builder()
  515. .SetRowsPerBuffer(2)
  516. .SetNumWorkers(4)
  517. .SetDataSchema(std::move(testSchema))
  518. .SetTotalRows(10)
  519. .SetSampler(std::move(seq_sampler))
  520. .Build(&myRandomDataOp);
  521. ASSERT_TRUE(rc.IsOk());
  522. rc = myTree->AssociateNode(myRandomDataOp);
  523. ASSERT_TRUE(rc.IsOk());
  524. // CacheOp
  525. CacheClient::Builder ccbuilder;
  526. // use arbitrary session of 1, size of 0, spilling// is true
  527. ccbuilder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true);
  528. std::shared_ptr<CacheClient> myClient;
  529. rc = ccbuilder.Build(&myClient);
  530. ASSERT_TRUE(rc.IsOk());
  531. std::shared_ptr<CacheOp> myCacheOp;
  532. rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
  533. ASSERT_TRUE(rc.IsOk());
  534. rc = myTree->AssociateNode(myCacheOp);
  535. ASSERT_TRUE(rc.IsOk());
  536. // RepeatOp
  537. uint32_t numRepeats = 4;
  538. std::shared_ptr<RepeatOp> myRepeatOp;
  539. rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  540. ASSERT_TRUE(rc.IsOk());
  541. rc = myTree->AssociateNode(myRepeatOp);
  542. ASSERT_TRUE(rc.IsOk());
  543. // Assign tree relations and root
  544. rc = myRepeatOp->AddChild(myCacheOp);
  545. ASSERT_TRUE(rc.IsOk());
  546. rc = myCacheOp->AddChild(myRandomDataOp);
  547. ASSERT_TRUE(rc.IsOk());
  548. rc = myTree->AssignRoot(myRepeatOp);
  549. ASSERT_TRUE(rc.IsOk());
  550. MS_LOG(INFO) << "Launching tree and begin iteration";
  551. rc = myTree->Prepare(1);
  552. ASSERT_TRUE(rc.IsOk());
  553. std::cout << *myClient << std::endl;
  554. rc = myTree->Launch();
  555. ASSERT_TRUE(rc.IsOk());
  556. // Start the loop of reading tensors from our pipeline
  557. DatasetIterator dI(myTree);
  558. TensorRow tensorList;
  559. rc = dI.FetchNextTensorRow(&tensorList);
  560. ASSERT_TRUE(rc.IsOk());
  561. int rowCount = 0;
  562. while (!tensorList.empty()) {
  563. // Don't display these rows, just count them
  564. MS_LOG(INFO) << "Row fetched #: " << rowCount;
  565. rc = dI.FetchNextTensorRow(&tensorList);
  566. ASSERT_TRUE(rc.IsOk());
  567. rowCount++;
  568. }
  569. ASSERT_EQ(rowCount, 40);
  570. rc = myClient->DestroyCache();
  571. ASSERT_TRUE(rc.IsOk());
  572. }