You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_zero_copy.cc 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <dirent.h>
  #include <sys/stat.h>
  #include <algorithm>
  #include <climits>
  #include <cstdlib>
  #include <fstream>
  #include <iostream>
  #include <memory>
  #include <string>
  #include <string_view>
  #include <utility>
  #include <vector>
  #include "common/common_test.h"
  #include "include/api/types.h"
  #include "minddata/dataset/include/execute.h"
  #include "minddata/dataset/include/transforms.h"
  #include "minddata/dataset/include/vision.h"
  #ifdef ENABLE_ACL
  #include "minddata/dataset/include/vision_ascend.h"
  #endif
  #include "minddata/dataset/kernels/tensor_op.h"
  #include "include/api/model.h"
  #include "include/api/serialization.h"
  #include "include/api/context.h"
  33. using namespace mindspore;
  34. using namespace mindspore::dataset;
  35. using namespace mindspore::dataset::vision;
  36. class TestZeroCopy : public ST::Common {
  37. public:
  38. TestZeroCopy() {}
  39. };
  40. constexpr auto resnet_file = "/home/workspace/mindspore_dataset/mindir/resnet50/resnet50_imagenet.mindir";
  41. constexpr auto image_path = "/home/workspace/mindspore_dataset/imagenet/imagenet_original/val/n01440764/";
  42. constexpr auto aipp_path = "./data/dataset/aipp_resnet50.cfg";
  43. size_t GetMax(mindspore::MSTensor data);
  44. std::string RealPath(std::string_view path);
  45. DIR *OpenDir(std::string_view dir_name);
  46. std::vector<std::string> GetAllFiles(std::string_view dir_name);
  47. TEST_F(TestZeroCopy, TestMindIR) {
  48. #ifdef ENABLE_ACL
  49. // Set context
  50. mindspore::GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
  51. mindspore::GlobalContext::SetGlobalDeviceID(0);
  52. auto model_context = std::make_shared<ModelContext>();
  53. ModelContext::SetInsertOpConfigPath(model_context,aipp_path);
  54. // Define model
  55. auto graph = mindspore::Serialization::LoadModel(resnet_file, mindspore::ModelType::kMindIR);
  56. mindspore::Model resnet50(mindspore::GraphCell(graph),model_context);
  57. // Build model
  58. ASSERT_TRUE(resnet50.Build() == kSuccess);
  59. // Get model info
  60. std::vector<mindspore::MSTensor> model_inputs =resnet50.GetInputs();
  61. ASSERT_EQ(model_inputs.size(), 1);
  62. // Define transform operations
  63. std::shared_ptr<TensorTransform> decode(new vision::Decode());
  64. std::shared_ptr<TensorTransform> resize(new vision::Resize({256}));
  65. std::shared_ptr<TensorTransform> center_crop(new vision::CenterCrop({224,224}));
  66. mindspore::dataset::Execute Transform({decode,resize,center_crop},MapTargetDevice::kAscend310);
  67. size_t count=0;
  68. // Read images
  69. std::vector<std::string> images =GetAllFiles(image_path);
  70. for(const auto &image_file:images){
  71. // prepare input
  72. std::vector<mindspore::MSTensor> inputs;
  73. std::vector<mindspore::MSTensor> outputs;
  74. std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
  75. mindspore::dataset::Tensor::CreateFromFile(image_file, &de_tensor);
  76. auto image = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));
  77. // Apply transform on images
  78. Status rc = Transform(image, &image);
  79. ASSERT_TRUE(rc == kSuccess);
  80. inputs.push_back(image);
  81. // infer
  82. ASSERT_TRUE(resnet50.Predict(inputs, &outputs)==kSuccess);
  83. if(GetMax(outputs[0])==0){
  84. ++count;
  85. }
  86. Transform.DeviceMemoryRelease();
  87. }
  88. ASSERT_GE(static_cast<double>(count)/images.size()*100.0, 20.0);
  89. #endif
  90. }
  91. size_t GetMax(mindspore::MSTensor data) {
  92. float max_value = -1;
  93. size_t max_idx = 0;
  94. const float *p = reinterpret_cast<const float *>(data.MutableData());
  95. for (size_t i = 0; i < data.DataSize() / sizeof(float); ++i) {
  96. if (p[i] > max_value) {
  97. max_value = p[i];
  98. max_idx = i;
  99. }
  100. }
  101. return max_idx;
  102. }
  // Canonicalizes `path` with POSIX realpath(). Returns an empty string if the
  // path cannot be resolved (for example, it does not exist).
  std::string RealPath(std::string_view path) {
    // A string_view is not guaranteed to be NUL-terminated, so copy it into a
    // std::string before handing it to the C API.
    const std::string path_str(path);
    char real_path_mem[PATH_MAX] = {0};
    if (realpath(path_str.c_str(), real_path_mem) == nullptr) {
      return "";
    }
    return std::string(real_path_mem);
  }
  111. DIR *OpenDir(std::string_view dir_name) {
  112. // check the parameter !
  113. if (dir_name.empty()) {
  114. return nullptr;
  115. }
  116. std::string real_path = RealPath(dir_name);
  117. // check if dir_name is a valid dir
  118. struct stat s;
  119. lstat(real_path.c_str(), &s);
  120. if (!S_ISDIR(s.st_mode)) {
  121. return nullptr;
  122. }
  123. DIR *dir;
  124. dir = opendir(real_path.c_str());
  125. if (dir == nullptr) {
  126. return nullptr;
  127. }
  128. return dir;
  129. }
  130. std::vector<std::string> GetAllFiles(std::string_view dir_name) {
  131. struct dirent *filename;
  132. DIR *dir = OpenDir(dir_name);
  133. if (dir == nullptr) {
  134. return {};
  135. }
  136. /* read all the files in the dir ~ */
  137. std::vector<std::string> res;
  138. while ((filename = readdir(dir)) != nullptr) {
  139. std::string d_name = std::string(filename->d_name);
  140. // get rid of "." and ".."
  141. if (d_name == "." || d_name == ".." || filename->d_type != DT_REG)
  142. continue;
  143. res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
  144. }
  145. std::sort(res.begin(), res.end());
  146. return res;
  147. }