You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_zero_copy.cc 8.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <dirent.h>
  #include <sys/stat.h>
  #include <sys/time.h>
  #include <algorithm>
  #include <climits>
  #include <cstdlib>
  #include <fstream>
  #include <iostream>
  #include <limits>
  #include <memory>
  #include <string>
  #include <string_view>
  #include <vector>
  #include "common/common_test.h"
  #include "include/api/types.h"
  #include "minddata/dataset/include/dataset/execute.h"
  #include "minddata/dataset/include/dataset/transforms.h"
  #include "minddata/dataset/include/dataset/vision.h"
  #ifdef ENABLE_ACL
  #include "minddata/dataset/include/dataset/vision_ascend.h"
  #endif
  #include "minddata/dataset/kernels/tensor_op.h"
  #include "include/api/model.h"
  #include "include/api/serialization.h"
  #include "include/api/context.h"
  34. using namespace mindspore;
  35. using namespace mindspore::dataset;
  36. using namespace mindspore::dataset::vision;
  37. class TestZeroCopy : public ST::Common {
  38. public:
  39. TestZeroCopy() {}
  40. };
  // Timestamp type filled in by gettimeofday() in the latency test.
  using TimeValue = timeval;

  // Model file handed to Serialization::Load() as ModelType::kMindIR.
  constexpr const char *resnet_file = "/home/workspace/mindspore_dataset/mindir/resnet50/resnet50_imagenet.mindir";
  // Directory whose regular files are used as input images.
  constexpr const char *image_path = "/home/workspace/mindspore_dataset/imagenet/imagenet_original/val/n01440764/";
  // Config file passed to Ascend310DeviceInfo::SetInsertOpConfigPath().
  constexpr const char *aipp_path = "./data/dataset/aipp_resnet50.cfg";

  // Microseconds per second, used to fold tv_sec and tv_usec into a single count.
  constexpr uint64_t kUSecondInSecond = 1000 * 1000;
  // Number of Predict() calls made inside each timed loop.
  constexpr uint64_t run_nums = 10;
  47. size_t GetMax(mindspore::MSTensor data);
  48. std::string RealPath(std::string_view path);
  49. DIR *OpenDir(std::string_view dir_name);
  50. std::vector<std::string> GetAllFiles(std::string_view dir_name);
  51. TEST_F(TestZeroCopy, TestMindIR) {
  52. #ifdef ENABLE_ACL
  53. // Set context
  54. auto context = ContextAutoSet();
  55. ASSERT_TRUE(context != nullptr);
  56. ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
  57. auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
  58. ASSERT_TRUE(ascend310_info != nullptr);
  59. ascend310_info->SetInsertOpConfigPath(aipp_path);
  60. auto device_id = ascend310_info->GetDeviceID();
  61. // Define model
  62. Graph graph;
  63. ASSERT_TRUE(Serialization::Load(resnet_file, ModelType::kMindIR, &graph) == kSuccess);
  64. Model resnet50;
  65. ASSERT_TRUE(resnet50.Build(GraphCell(graph), context) == kSuccess);
  66. // Get model info
  67. std::vector<mindspore::MSTensor> model_inputs = resnet50.GetInputs();
  68. ASSERT_EQ(model_inputs.size(), 1);
  69. // Define transform operations
  70. std::shared_ptr<TensorTransform> decode(new vision::Decode());
  71. std::shared_ptr<TensorTransform> resize(new vision::Resize({256}));
  72. std::shared_ptr<TensorTransform> center_crop(new vision::CenterCrop({224, 224}));
  73. mindspore::dataset::Execute Transform({decode, resize, center_crop}, MapTargetDevice::kAscend310, device_id);
  74. size_t count = 0;
  75. // Read images
  76. std::vector<std::string> images = GetAllFiles(image_path);
  77. for (const auto &image_file : images) {
  78. // prepare input
  79. std::vector<mindspore::MSTensor> inputs;
  80. std::vector<mindspore::MSTensor> outputs;
  81. std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
  82. mindspore::dataset::Tensor::CreateFromFile(image_file, &de_tensor);
  83. auto image = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));
  84. // Apply transform on images
  85. Status rc = Transform(image, &image);
  86. ASSERT_TRUE(rc == kSuccess);
  87. inputs.push_back(image);
  88. // infer
  89. ASSERT_TRUE(resnet50.Predict(inputs, &outputs) == kSuccess);
  90. if (GetMax(outputs[0]) == 0) {
  91. ++count;
  92. }
  93. Transform.DeviceMemoryRelease();
  94. }
  95. ASSERT_GE(static_cast<double>(count) / images.size() * 100.0, 20.0);
  96. #endif
  97. }
  98. TEST_F(TestZeroCopy, TestDeviceTensor) {
  99. #ifdef ENABLE_ACL
  100. // Set context
  101. auto context = ContextAutoSet();
  102. ASSERT_TRUE(context != nullptr);
  103. ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
  104. auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
  105. ASSERT_TRUE(ascend310_info != nullptr);
  106. ascend310_info->SetInsertOpConfigPath(aipp_path);
  107. auto device_id = ascend310_info->GetDeviceID();
  108. // Define model
  109. Graph graph;
  110. ASSERT_TRUE(Serialization::Load(resnet_file, ModelType::kMindIR, &graph) == kSuccess);
  111. Model resnet50;
  112. ASSERT_TRUE(resnet50.Build(GraphCell(graph), context) == kSuccess);
  113. // Get model info
  114. std::vector<mindspore::MSTensor> model_inputs = resnet50.GetInputs();
  115. ASSERT_EQ(model_inputs.size(), 1);
  116. // Define transform operations
  117. std::shared_ptr<TensorTransform> decode(new vision::Decode());
  118. std::shared_ptr<TensorTransform> resize(new vision::Resize({256}));
  119. std::shared_ptr<TensorTransform> center_crop(new vision::CenterCrop({224, 224}));
  120. mindspore::dataset::Execute Transform({decode, resize, center_crop}, MapTargetDevice::kAscend310, device_id);
  121. // Read images
  122. std::vector<std::string> images = GetAllFiles(image_path);
  123. uint64_t cost = 0, device_cost = 0;
  124. for (const auto &image_file : images) {
  125. // prepare input
  126. std::vector<mindspore::MSTensor> inputs;
  127. std::vector<mindspore::MSTensor> outputs;
  128. std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
  129. mindspore::dataset::Tensor::CreateFromFile(image_file, &de_tensor);
  130. auto image = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));
  131. // Apply transform on images
  132. Status rc = Transform(image, &image);
  133. ASSERT_TRUE(rc == kSuccess);
  134. MSTensor *device_tensor =
  135. MSTensor::CreateDevTensor(image.Name(), image.DataType(), image.Shape(),
  136. image.MutableData(), image.DataSize());
  137. MSTensor *tensor =
  138. MSTensor::CreateTensor(image.Name(), image.DataType(), image.Shape(),
  139. image.Data().get(), image.DataSize());
  140. inputs.push_back(*tensor);
  141. // infer
  142. TimeValue start_time, end_time;
  143. (void)gettimeofday(&start_time, nullptr);
  144. for (size_t i = 0; i < run_nums; ++i) {
  145. ASSERT_TRUE(resnet50.Predict(inputs, &outputs) == kSuccess);
  146. }
  147. (void)gettimeofday(&end_time, nullptr);
  148. cost +=
  149. (kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec) + static_cast<uint64_t>(end_time.tv_usec)) -
  150. (kUSecondInSecond * static_cast<uint64_t>(start_time.tv_sec) + static_cast<uint64_t>(start_time.tv_usec));
  151. // clear inputs
  152. inputs.clear();
  153. start_time = (TimeValue){0};
  154. end_time = (TimeValue){0};
  155. inputs.push_back(*device_tensor);
  156. // infer with device tensor
  157. (void)gettimeofday(&start_time, nullptr);
  158. for (size_t i = 0; i < run_nums; ++i) {
  159. ASSERT_TRUE(resnet50.Predict(inputs, &outputs) == kSuccess);
  160. }
  161. (void)gettimeofday(&end_time, nullptr);
  162. device_cost +=
  163. (kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec) + static_cast<uint64_t>(end_time.tv_usec)) -
  164. (kUSecondInSecond * static_cast<uint64_t>(start_time.tv_sec) + static_cast<uint64_t>(start_time.tv_usec));
  165. Transform.DeviceMemoryRelease();
  166. }
  167. ASSERT_GE(cost, device_cost);
  168. #endif
  169. }
  170. size_t GetMax(mindspore::MSTensor data) {
  171. float max_value = -1;
  172. size_t max_idx = 0;
  173. const float *p = reinterpret_cast<const float *>(data.MutableData());
  174. for (size_t i = 0; i < data.DataSize() / sizeof(float); ++i) {
  175. if (p[i] > max_value) {
  176. max_value = p[i];
  177. max_idx = i;
  178. }
  179. }
  180. return max_idx;
  181. }
  // Resolves `path` with realpath() and returns the canonical absolute path.
  // Returns an empty string when realpath() fails (for example, the path does not exist).
  std::string RealPath(std::string_view path) {
    // A string_view is not guaranteed to be NUL-terminated, but realpath() requires a C
    // string; copy into std::string so exactly the viewed characters are resolved.
    const std::string terminated_path(path);
    char resolved[PATH_MAX] = {0};
    if (realpath(terminated_path.c_str(), resolved) == nullptr) {
      return "";
    }
    return std::string(resolved);
  }
  190. DIR *OpenDir(std::string_view dir_name) {
  191. // check the parameter !
  192. if (dir_name.empty()) {
  193. return nullptr;
  194. }
  195. std::string real_path = RealPath(dir_name);
  196. // check if dir_name is a valid dir
  197. struct stat s;
  198. lstat(real_path.c_str(), &s);
  199. if (!S_ISDIR(s.st_mode)) {
  200. return nullptr;
  201. }
  202. DIR *dir;
  203. dir = opendir(real_path.c_str());
  204. if (dir == nullptr) {
  205. return nullptr;
  206. }
  207. return dir;
  208. }
  209. std::vector<std::string> GetAllFiles(std::string_view dir_name) {
  210. struct dirent *filename;
  211. DIR *dir = OpenDir(dir_name);
  212. if (dir == nullptr) {
  213. return {};
  214. }
  215. /* read all the files in the dir ~ */
  216. std::vector<std::string> res;
  217. while ((filename = readdir(dir)) != nullptr) {
  218. std::string d_name = std::string(filename->d_name);
  219. // get rid of "." and ".."
  220. if (d_name == "." || d_name == ".." || filename->d_type != DT_REG) continue;
  221. res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
  222. }
  223. std::sort(res.begin(), res.end());
  224. return res;
  225. }