You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_zero_copy.cc 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <dirent.h>
  #include <sys/stat.h>
  #include <algorithm>
  #include <climits>
  #include <cstdint>
  #include <cstdlib>
  #include <fstream>
  #include <iostream>
  #include <memory>
  #include <string>
  #include <string_view>
  #include <vector>
  #include "common/common_test.h"
  #include "include/api/types.h"
  #include "minddata/dataset/include/execute.h"
  #include "minddata/dataset/include/transforms.h"
  #include "minddata/dataset/include/vision.h"
  #ifdef ENABLE_ACL
  #include "minddata/dataset/include/vision_ascend.h"
  #endif
  #include "minddata/dataset/kernels/tensor_op.h"
  #include "include/api/model.h"
  #include "include/api/serialization.h"
  #include "include/api/context.h"
  using namespace mindspore;
  using namespace mindspore::dataset;
  using namespace mindspore::dataset::vision;
  36. class TestZeroCopy : public ST::Common {
  37. public:
  38. TestZeroCopy() {}
  39. };
  40. constexpr auto resnet_file = "/home/workspace/mindspore_dataset/mindir/resnet50/resnet50_imagenet.mindir";
  41. constexpr auto image_path = "/home/workspace/mindspore_dataset/imagenet/imagenet_original/val/n01440764/";
  42. constexpr auto aipp_path = "./data/dataset/aipp_resnet50.cfg";
  43. size_t GetMax(mindspore::MSTensor data);
  44. std::string RealPath(std::string_view path);
  45. DIR *OpenDir(std::string_view dir_name);
  46. std::vector<std::string> GetAllFiles(std::string_view dir_name);
  47. TEST_F(TestZeroCopy, TestMindIR) {
  48. #ifdef ENABLE_ACL
  49. // Set context
  50. auto context = ContextAutoSet();
  51. ASSERT_TRUE(context != nullptr);
  52. ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
  53. auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
  54. ASSERT_TRUE(ascend310_info != nullptr);
  55. ascend310_info->SetInsertOpConfigPath(aipp_path);
  56. auto device_id = ascend310_info->GetDeviceID();
  57. // Define model
  58. Graph graph;
  59. ASSERT_TRUE(Serialization::Load(resnet_file, ModelType::kMindIR, &graph) == kSuccess);
  60. Model resnet50;
  61. ASSERT_TRUE(resnet50.Build(GraphCell(graph), context) == kSuccess);
  62. // Get model info
  63. std::vector<mindspore::MSTensor> model_inputs = resnet50.GetInputs();
  64. ASSERT_EQ(model_inputs.size(), 1);
  65. // Define transform operations
  66. std::shared_ptr<TensorTransform> decode(new vision::Decode());
  67. std::shared_ptr<TensorTransform> resize(new vision::Resize({256}));
  68. std::shared_ptr<TensorTransform> center_crop(new vision::CenterCrop({224, 224}));
  69. mindspore::dataset::Execute Transform({decode, resize, center_crop}, MapTargetDevice::kAscend310, device_id);
  70. size_t count = 0;
  71. // Read images
  72. std::vector<std::string> images = GetAllFiles(image_path);
  73. for (const auto &image_file : images) {
  74. // prepare input
  75. std::vector<mindspore::MSTensor> inputs;
  76. std::vector<mindspore::MSTensor> outputs;
  77. std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
  78. mindspore::dataset::Tensor::CreateFromFile(image_file, &de_tensor);
  79. auto image = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));
  80. // Apply transform on images
  81. Status rc = Transform(image, &image);
  82. ASSERT_TRUE(rc == kSuccess);
  83. inputs.push_back(image);
  84. // infer
  85. ASSERT_TRUE(resnet50.Predict(inputs, &outputs) == kSuccess);
  86. if (GetMax(outputs[0]) == 0) {
  87. ++count;
  88. }
  89. Transform.DeviceMemoryRelease();
  90. }
  91. ASSERT_GE(static_cast<double>(count) / images.size() * 100.0, 20.0);
  92. #endif
  93. }
  94. size_t GetMax(mindspore::MSTensor data) {
  95. float max_value = -1;
  96. size_t max_idx = 0;
  97. const float *p = reinterpret_cast<const float *>(data.MutableData());
  98. for (size_t i = 0; i < data.DataSize() / sizeof(float); ++i) {
  99. if (p[i] > max_value) {
  100. max_value = p[i];
  101. max_idx = i;
  102. }
  103. }
  104. return max_idx;
  105. }
  // Canonicalizes `path` with POSIX realpath(), resolving symlinks, "." and "..".
  // Returns an empty string when `path` is empty or cannot be resolved
  // (for example, when it does not exist).
  std::string RealPath(std::string_view path) {
    if (path.empty()) {
      return "";
    }
    // A std::string_view is not guaranteed to be NUL-terminated, but realpath()
    // reads a C string. Make an owning, terminated copy so only the viewed
    // characters are used.
    const std::string c_path(path);
    char real_path_mem[PATH_MAX] = {0};
    if (realpath(c_path.c_str(), real_path_mem) == nullptr) {
      return "";
    }
    return std::string(real_path_mem);
  }
  114. DIR *OpenDir(std::string_view dir_name) {
  115. // check the parameter !
  116. if (dir_name.empty()) {
  117. return nullptr;
  118. }
  119. std::string real_path = RealPath(dir_name);
  120. // check if dir_name is a valid dir
  121. struct stat s;
  122. lstat(real_path.c_str(), &s);
  123. if (!S_ISDIR(s.st_mode)) {
  124. return nullptr;
  125. }
  126. DIR *dir;
  127. dir = opendir(real_path.c_str());
  128. if (dir == nullptr) {
  129. return nullptr;
  130. }
  131. return dir;
  132. }
  133. std::vector<std::string> GetAllFiles(std::string_view dir_name) {
  134. struct dirent *filename;
  135. DIR *dir = OpenDir(dir_name);
  136. if (dir == nullptr) {
  137. return {};
  138. }
  139. /* read all the files in the dir ~ */
  140. std::vector<std::string> res;
  141. while ((filename = readdir(dir)) != nullptr) {
  142. std::string d_name = std::string(filename->d_name);
  143. // get rid of "." and ".."
  144. if (d_name == "." || d_name == ".." || filename->d_type != DT_REG) continue;
  145. res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
  146. }
  147. std::sort(res.begin(), res.end());
  148. return res;
  149. }