You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

cache_op_test.cc 11 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <string>
  17. #include "minddata/dataset/core/client.h"
  18. #include "minddata/dataset/engine/cache/cache_client.h"
  19. #include "minddata/dataset/engine/execution_tree.h"
  20. #include "minddata/dataset/engine/datasetops/cache_op.h"
  21. #include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
  22. #include "minddata/dataset/engine/datasetops/cache_merge_op.h"
  23. #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
  24. #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
  25. #include "minddata/dataset/engine/jagged_connector.h"
  26. #include "common/common.h"
  27. #include "gtest/gtest.h"
  28. #include "utils/log_adapter.h"
  29. #include "minddata/dataset/engine/datasetops/source/random_data_op.h"
  30. #include "minddata/dataset/engine/data_schema.h"
  31. using namespace mindspore::dataset;
  32. using mindspore::LogStream;
  33. using mindspore::dataset::CacheClient;
  34. using mindspore::dataset::TaskGroup;
  35. using mindspore::ExceptionType::NoExceptionType;
  36. using mindspore::MsLogLevel::INFO;
  37. // Helper function to get the session id from SESSION_ID env variable
  38. Status GetSessionFromEnv(session_id_type *session_id) {
  39. RETURN_UNEXPECTED_IF_NULL(session_id);
  40. if (const char *session_env = std::getenv("SESSION_ID")) {
  41. std::string session_id_str(session_env);
  42. try {
  43. *session_id = std::stoul(session_id_str);
  44. } catch (const std::exception &e) {
  45. std::string err_msg = "Invalid numeric value for session id in env var: " + session_id_str;
  46. return Status(StatusCode::kMDSyntaxError, err_msg);
  47. }
  48. } else {
  49. RETURN_STATUS_UNEXPECTED("Test case requires a session id to be provided via SESSION_ID environment variable.");
  50. }
  51. return Status::OK();
  52. }
  53. class MindDataTestCacheOp : public UT::DatasetOpTesting {
  54. public:
  55. void SetUp() override {
  56. DatasetOpTesting::SetUp();
  57. GlobalInit();
  58. }
  59. };
  60. TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) {
  61. Status rc;
  62. CacheClient::Builder builder;
  63. session_id_type env_session;
  64. rc = GetSessionFromEnv(&env_session);
  65. ASSERT_TRUE(rc.IsOk());
  66. // use arbitrary session of 1, size of 0, spilling// is true
  67. builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
  68. std::shared_ptr<CacheClient> myClient;
  69. rc = builder.Build(&myClient);
  70. ASSERT_TRUE(rc.IsOk());
  71. // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
  72. rc = myClient->CreateCache(1, true);
  73. ASSERT_TRUE(rc.IsOk());
  74. std::cout << *myClient << std::endl;
  75. // Create a schema using the C api's
  76. int32_t rank = 0; // not used
  77. std::unique_ptr<DataSchema> test_schema = std::make_unique<DataSchema>();
  78. // 2 columns. First column is an "image" 640,480,3
  79. TensorShape c1Shape({640, 480, 3});
  80. ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
  81. rank, // not used
  82. &c1Shape);
  83. // Column 2 will just be a scalar label number
  84. TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
  85. ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
  86. test_schema->AddColumn(c1);
  87. test_schema->AddColumn(c2);
  88. std::unordered_map<std::string, int32_t> map;
  89. rc = test_schema->GetColumnNameMap(&map);
  90. ASSERT_TRUE(rc.IsOk());
  91. // Test the CacheSchema api
  92. rc = myClient->CacheSchema(map);
  93. ASSERT_TRUE(rc.IsOk());
  94. // Create a tensor, take a snapshot and restore it back, and compare.
  95. std::shared_ptr<Tensor> t;
  96. Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t);
  97. t->SetItemAt<uint64_t>({0, 0}, 1);
  98. t->SetItemAt<uint64_t>({0, 1}, 2);
  99. t->SetItemAt<uint64_t>({0, 2}, 3);
  100. t->SetItemAt<uint64_t>({1, 0}, 4);
  101. t->SetItemAt<uint64_t>({1, 1}, 5);
  102. t->SetItemAt<uint64_t>({1, 2}, 6);
  103. std::cout << *t << std::endl;
  104. TensorTable tbl;
  105. TensorRow row;
  106. row.push_back(t);
  107. int64_t row_id;
  108. rc = myClient->WriteRow(row, &row_id);
  109. ASSERT_TRUE(rc.IsOk());
  110. // Switch off build phase.
  111. rc = myClient->BuildPhaseDone();
  112. ASSERT_TRUE(rc.IsOk());
  113. // Now restore from cache.
  114. row.clear();
  115. rc = myClient->GetRows({row_id}, &tbl);
  116. row = tbl.front();
  117. ASSERT_TRUE(rc.IsOk());
  118. auto r = row.front();
  119. std::cout << *r << std::endl;
  120. // Compare
  121. bool cmp = (*t == *r);
  122. ASSERT_TRUE(cmp);
  123. // Get back the schema and verify
  124. std::unordered_map<std::string, int32_t> map_out;
  125. rc = myClient->FetchSchema(&map_out);
  126. ASSERT_TRUE(rc.IsOk());
  127. cmp = (map_out == map);
  128. ASSERT_TRUE(cmp);
  129. rc = myClient->DestroyCache();
  130. ASSERT_TRUE(rc.IsOk());
  131. }
  132. TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) {
  133. // Clear the rc of the master thread if any
  134. (void)TaskManager::GetMasterThreadRc();
  135. TaskGroup vg;
  136. Status rc;
  137. session_id_type env_session;
  138. rc = GetSessionFromEnv(&env_session);
  139. ASSERT_TRUE(rc.IsOk());
  140. // use arbitrary session of 1, size 1, spilling is true
  141. CacheClient::Builder builder;
  142. // use arbitrary session of 1, size of 0, spilling// is true
  143. builder.SetSessionId(env_session).SetCacheMemSz(1).SetSpill(true);
  144. std::shared_ptr<CacheClient> myClient;
  145. rc = builder.Build(&myClient);
  146. ASSERT_TRUE(rc.IsOk());
  147. // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
  148. rc = myClient->CreateCache(1, true);
  149. ASSERT_TRUE(rc.IsOk());
  150. std::cout << *myClient << std::endl;
  151. std::shared_ptr<Tensor> t;
  152. Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t);
  153. t->SetItemAt<uint64_t>({0, 0}, 1);
  154. t->SetItemAt<uint64_t>({0, 1}, 2);
  155. t->SetItemAt<uint64_t>({0, 2}, 3);
  156. t->SetItemAt<uint64_t>({1, 0}, 4);
  157. t->SetItemAt<uint64_t>({1, 1}, 5);
  158. t->SetItemAt<uint64_t>({1, 2}, 6);
  159. TensorTable tbl;
  160. TensorRow row;
  161. row.push_back(t);
  162. // Cache tensor row t 5000 times using 10 threads.
  163. for (auto k = 0; k < 10; ++k) {
  164. Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status {
  165. TaskManager::FindMe()->Post();
  166. for (auto i = 0; i < 500; i++) {
  167. RETURN_IF_NOT_OK(myClient->WriteRow(row));
  168. }
  169. return Status::OK();
  170. });
  171. ASSERT_TRUE(vg_rc.IsOk());
  172. }
  173. ASSERT_TRUE(vg.join_all().IsOk());
  174. ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk());
  175. rc = myClient->BuildPhaseDone();
  176. ASSERT_TRUE(rc.IsOk());
  177. // Get statistics from the server.
  178. CacheServiceStat stat{};
  179. rc = myClient->GetStat(&stat);
  180. ASSERT_TRUE(rc.IsOk());
  181. std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached
  182. << "\n";
  183. // Expect there are 5000 rows there.
  184. EXPECT_EQ(5000, stat.max_row_id - stat.min_row_id + 1);
  185. // Get them all back using row id and compare with tensor t.
  186. for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) {
  187. tbl.clear();
  188. row.clear();
  189. rc = myClient->GetRows({i}, &tbl);
  190. ASSERT_TRUE(rc.IsOk());
  191. row = tbl.front();
  192. auto r = row.front();
  193. bool cmp = (*t == *r);
  194. ASSERT_TRUE(cmp);
  195. }
  196. rc = myClient->DestroyCache();
  197. ASSERT_TRUE(rc.IsOk());
  198. }
  199. TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
  200. // Clear the rc of the master thread if any
  201. (void)TaskManager::GetMasterThreadRc();
  202. Status rc;
  203. int64_t num_samples = 0;
  204. int64_t start_index = 0;
  205. session_id_type env_session;
  206. rc = GetSessionFromEnv(&env_session);
  207. ASSERT_TRUE(rc.IsOk());
  208. auto seq_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
  209. CacheClient::Builder ccbuilder;
  210. ccbuilder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
  211. std::shared_ptr<CacheClient> myClient;
  212. rc = ccbuilder.Build(&myClient);
  213. ASSERT_TRUE(rc.IsOk());
  214. std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
  215. int32_t op_connector_size = config_manager->op_connector_size();
  216. std::shared_ptr<CacheLookupOp> myLookupOp =
  217. std::make_shared<CacheLookupOp>(4, op_connector_size, myClient, std::move(seq_sampler));
  218. ASSERT_NE(myLookupOp, nullptr);
  219. std::shared_ptr<CacheMergeOp> myMergeOp = std::make_shared<CacheMergeOp>(4, op_connector_size, 4, myClient);
  220. ASSERT_NE(myMergeOp, nullptr);
  221. std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  222. TensorShape scalar = TensorShape::CreateScalar();
  223. rc = schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1));
  224. ASSERT_TRUE(rc.IsOk());
  225. rc = schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar));
  226. ASSERT_TRUE(rc.IsOk());
  227. std::string dataset_path = datasets_root_path_ + "/testPK/data";
  228. std::set<std::string> ext = {".jpg", ".JPEG"};
  229. bool recursive = true;
  230. bool decode = false;
  231. std::map<std::string, int32_t> columns_to_load = {};
  232. std::shared_ptr<ImageFolderOp> so = std::make_shared<ImageFolderOp>(
  233. 3, dataset_path, 3, recursive, decode, ext, columns_to_load, std::move(schema), nullptr);
  234. so->SetSampler(myLookupOp);
  235. ASSERT_TRUE(rc.IsOk());
  236. // RepeatOp
  237. uint32_t num_repeats = 4;
  238. std::shared_ptr<RepeatOp> myRepeatOp = std::make_shared<RepeatOp>(num_repeats);
  239. auto myTree = std::make_shared<ExecutionTree>();
  240. rc = myTree->AssociateNode(so);
  241. ASSERT_TRUE(rc.IsOk());
  242. rc = myTree->AssociateNode(myLookupOp);
  243. ASSERT_TRUE(rc.IsOk());
  244. rc = myTree->AssociateNode(myMergeOp);
  245. ASSERT_TRUE(rc.IsOk());
  246. rc = myTree->AssociateNode(myRepeatOp);
  247. ASSERT_TRUE(rc.IsOk());
  248. rc = myTree->AssignRoot(myRepeatOp);
  249. ASSERT_TRUE(rc.IsOk());
  250. myMergeOp->SetTotalRepeats(num_repeats);
  251. myMergeOp->SetNumRepeatsPerEpoch(num_repeats);
  252. rc = myRepeatOp->AddChild(myMergeOp);
  253. ASSERT_TRUE(rc.IsOk());
  254. myLookupOp->SetTotalRepeats(num_repeats);
  255. myLookupOp->SetNumRepeatsPerEpoch(num_repeats);
  256. rc = myMergeOp->AddChild(myLookupOp);
  257. ASSERT_TRUE(rc.IsOk());
  258. so->SetTotalRepeats(num_repeats);
  259. so->SetNumRepeatsPerEpoch(num_repeats);
  260. rc = myMergeOp->AddChild(so);
  261. ASSERT_TRUE(rc.IsOk());
  262. rc = myTree->Prepare();
  263. ASSERT_TRUE(rc.IsOk());
  264. rc = myTree->Launch();
  265. ASSERT_TRUE(rc.IsOk());
  266. // Start the loop of reading tensors from our pipeline
  267. DatasetIterator dI(myTree);
  268. TensorRow tensorList;
  269. rc = dI.FetchNextTensorRow(&tensorList);
  270. ASSERT_TRUE(rc.IsOk());
  271. int rowCount = 0;
  272. while (!tensorList.empty()) {
  273. rc = dI.FetchNextTensorRow(&tensorList);
  274. ASSERT_TRUE(rc.IsOk());
  275. if (rc.IsError()) {
  276. std::cout << rc << std::endl;
  277. break;
  278. }
  279. rowCount++;
  280. }
  281. ASSERT_EQ(rowCount, 176);
  282. std::cout << "Row count : " << rowCount << std::endl;
  283. rc = myClient->DestroyCache();
  284. ASSERT_TRUE(rc.IsOk());
  285. }